aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-08-16 14:44:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 14:47:49 -0700
commitd43820b9eff0cc863de2bbfb142afe92bf5afd00 (patch)
treee019bdc0bb52496436e0ff92d76b728165e3a420
parent8235e83c442744a1285ce97e5dfc2a6556f9f667 (diff)
Improve gather ergonomics by renaming fields.
This CL renames the various inputs to the Gather HLO to be more mnemonic by making it more obviously a batch dynamic-slice. The replacements I made are: s/elided_window_dims/collapsed_slice_dims/g s/window_bounds/slice_sizes/g s/gather_dims_to_operand_dims/start_index_map/g s/gather_indices/start_indices/g s/output_window_dims/offset_dims/g PiperOrigin-RevId: 209051067
-rw-r--r--tensorflow/compiler/tf2xla/kernels/gather_op.cc32
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.cc14
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc26
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h12
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc56
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc11
-rw-r--r--tensorflow/compiler/xla/service/gather_expander.cc180
-rw-r--r--tensorflow/compiler/xla/service/gather_expander_test.cc16
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto2
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc161
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc112
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc24
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h8
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc66
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc57
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h16
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc35
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc2
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.cc43
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis_test.cc300
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc201
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h4
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc269
-rw-r--r--tensorflow/compiler/xla/tests/gather_operation_test.cc312
-rw-r--r--tensorflow/compiler/xla/xla_data.proto18
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md267
28 files changed, 1103 insertions, 1153 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index 35de96e0aa..44140304fd 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -95,11 +95,11 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
// operand = s32[3,3] parameter(0)
// indices = s32[2] parameter(1)
// gather = s32[3,2] gather(operand, indices),
- // output_window_dims={0},
- // elided_window_dims={1},
- // gather_dims_to_operand_dims={1},
+ // offset_dims={0},
+ // collapsed_slice_dims={1},
+ // start_index_map={1},
// index_vector_dim=1,
- // window_bounds={3, 1}
+ // slice_sizes={3, 1}
//
//
// Example of an N-D gather pulling out slices of shape [1,1,2] out of a
@@ -108,42 +108,42 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
// operand = s32[3,3,2] parameter(0)
// indices = s32[2,2] parameter(1)
// gather = s32[2,2] gather(operand, indices),
- // output_window_dims={1},
- // elided_window_dims={0,1},
- // gather_dims_to_operand_dims={0,1},
+ // offset_dims={1},
+ // collapsed_slice_dims={0,1},
+ // start_index_map={0,1},
// index_vector_dim=0,
- // window_bounds={1,1,2}
+ // slice_sizes={1,1,2}
xla::GatherDimensionNumbers dim_numbers;
- std::vector<int64> window_bounds;
- window_bounds.reserve(input_shape.dims());
+ std::vector<int64> slice_sizes;
+ slice_sizes.reserve(input_shape.dims());
for (int64 i = 0; i < input_shape.dims(); i++) {
int64 window_bound;
if (axis <= i && i < (axis + num_index_dims)) {
- dim_numbers.add_elided_window_dims(i);
+ dim_numbers.add_collapsed_slice_dims(i);
window_bound = 1;
} else {
window_bound = input_shape.dim_size(i);
}
- window_bounds.push_back(window_bound);
+ slice_sizes.push_back(window_bound);
if (i < axis) {
- dim_numbers.add_output_window_dims(i);
+ dim_numbers.add_offset_dims(i);
} else if (i >= (axis + num_index_dims)) {
int64 indices_rank =
indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims();
- dim_numbers.add_output_window_dims(i + indices_rank - num_index_dims);
+ dim_numbers.add_offset_dims(i + indices_rank - num_index_dims);
}
}
dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1)
: indices_shape.dims());
for (int64 i = axis; i < axis + num_index_dims; i++) {
- dim_numbers.add_gather_dims_to_operand_dims(i);
+ dim_numbers.add_start_index_map(i);
}
- *gather_output = xla::Gather(input, indices, dim_numbers, window_bounds);
+ *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes);
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index 04fa10108c..febb638e5e 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -57,7 +57,7 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) {
// We can grab entire blocks using gather
if (n > block_size) {
// Construct the starting indices of the diagonal blocks
- auto gather_indices =
+ auto start_indices =
Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks),
xla::ConstantR0<int32>(builder, block_size)),
/*broadcast_sizes=*/{2}),
@@ -65,13 +65,13 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) {
// Gather the diagonal blocks
xla::GatherDimensionNumbers dim_numbers;
- dim_numbers.add_output_window_dims(ndims - 1);
- dim_numbers.add_output_window_dims(ndims);
- dim_numbers.add_gather_dims_to_operand_dims(ndims - 2);
- dim_numbers.add_gather_dims_to_operand_dims(ndims - 1);
+ dim_numbers.add_offset_dims(ndims - 1);
+ dim_numbers.add_offset_dims(ndims);
+ dim_numbers.add_start_index_map(ndims - 2);
+ dim_numbers.add_start_index_map(ndims - 1);
dim_numbers.set_index_vector_dim(1);
- diag_blocks = Gather(a, gather_indices, dim_numbers,
- /*window_bounds=*/{block_size, block_size});
+ diag_blocks = Gather(a, start_indices, dim_numbers,
+ /*slice_sizes=*/{block_size, block_size});
}
// The last block might be smaller than the block size,
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index aa47f992bc..4dffab3c2c 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1631,27 +1631,27 @@ XlaOp XlaBuilder::While(const XlaComputation& condition,
});
}
-XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices,
+XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
- TF_ASSIGN_OR_RETURN(const Shape& gather_indices_shape,
- GetShape(gather_indices));
+ TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
+ GetShape(start_indices));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
- ShapeInference::InferGatherShape(input_shape, gather_indices_shape,
- dimension_numbers, window_bounds));
+ ShapeInference::InferGatherShape(input_shape, start_indices_shape,
+ dimension_numbers, slice_sizes));
*instr.mutable_gather_dimension_numbers() = dimension_numbers;
- for (int64 bound : window_bounds) {
- instr.add_gather_window_bounds(bound);
+ for (int64 bound : slice_sizes) {
+ instr.add_gather_slice_sizes(bound);
}
return AddInstruction(std::move(instr), HloOpcode::kGather,
- {input, gather_indices});
+ {input, start_indices});
});
}
@@ -2906,11 +2906,11 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
mantissa_bits);
}
-XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
- return input.builder()->Gather(input, gather_indices, dimension_numbers,
- window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ return input.builder()->Gather(input, start_indices, dimension_numbers,
+ slice_sizes);
}
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 78aec770a6..469d5048b2 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -877,9 +877,9 @@ class XlaBuilder {
const int mantissa_bits);
// Enqueues a Gather node onto the computation.
- XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
@@ -1328,9 +1328,9 @@ class XlaBuilder {
const XlaComputation& false_computation);
friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits);
- friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
const XlaOp& updates,
const XlaComputation& update_computation,
@@ -2024,9 +2024,9 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits);
// Enqueues a Gather node onto the computation.
-XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 8cbdc36f84..e6130c7d76 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -792,11 +792,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[3,2] broadcast(one), dimensions={}
ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
@@ -808,11 +808,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,3,2] broadcast(one), dimensions={}
ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
@@ -824,11 +824,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -840,11 +840,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -856,11 +856,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -872,11 +872,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[1,1] broadcast(one), dimensions={}
ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
@@ -888,11 +888,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[2,1,1] broadcast(one), dimensions={}
ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 2e9d6be2de..891ae42141 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -1672,22 +1672,21 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1);
for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0;
i < e; i++) {
- if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
+ if (c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
operand_index.push_back(index.GetConstantWithIndexType(0));
} else {
- int64 output_window_dim =
- dim_numbers.output_window_dims(operand_index_dim++);
+ int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++);
operand_to_output_dim[i] = output_window_dim;
operand_index.push_back(index[output_window_dim]);
}
}
- // This is the index of the index vector in the gather_indices tensor.
+ // This is the index of the index vector in the start_indices tensor.
IrArray::Index gather_index_index(index_type);
{
std::vector<llvm::Value*> gather_index_index_components;
for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
- if (!c_binary_search(dim_numbers.output_window_dims(), i)) {
+ if (!c_binary_search(dim_numbers.offset_dims(), i)) {
gather_index_index.push_back(index[i]);
}
}
@@ -1700,7 +1699,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
llvm::Value* gather_dim_component_extended =
b_->CreateSExtOrTrunc(index_component, index_type);
- int64 operand_dim = dim_numbers.gather_dims_to_operand_dims(dim);
+ int64 operand_dim = dim_numbers.start_index_map(dim);
int64 output_dim = operand_to_output_dim[operand_dim];
// If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
// This means we set the iteration index to 0, so for the purpose of the
diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc
index e3a42d0d06..9370c88710 100644
--- a/tensorflow/compiler/xla/service/gather_expander.cc
+++ b/tensorflow/compiler/xla/service/gather_expander.cc
@@ -27,85 +27,85 @@ namespace xla {
using tensorflow::gtl::ArraySlice;
static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
- HloInstruction* gather_indices, int64 index_vector_dim) {
- const Shape& gather_indices_shape = gather_indices->shape();
+ HloInstruction* start_indices, int64 index_vector_dim) {
+ const Shape& start_indices_shape = start_indices->shape();
- if (gather_indices_shape.dimensions_size() == index_vector_dim) {
- return gather_indices;
+ if (start_indices_shape.dimensions_size() == index_vector_dim) {
+ return start_indices;
}
- if (index_vector_dim == (gather_indices_shape.dimensions_size() - 1)) {
- return gather_indices;
+ if (index_vector_dim == (start_indices_shape.dimensions_size() - 1)) {
+ return start_indices;
}
std::vector<int64> permutation;
- permutation.reserve(gather_indices_shape.dimensions_size());
- for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
+ permutation.reserve(start_indices_shape.dimensions_size());
+ for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
if (i != index_vector_dim) {
permutation.push_back(i);
}
}
permutation.push_back(index_vector_dim);
- return MakeTransposeHlo(gather_indices, permutation);
+ return MakeTransposeHlo(start_indices, permutation);
}
-// Canonicalizes the gather_indices tensors so that we only have deal with some
+// Canonicalizes the start_indices tensors so that we only have deal with some
// specific cases in the while loop that does the heavy lifting.
//
// See the "High Level Algorithm" section for a broader picture.
static StatusOr<HloInstruction*> CanonicalizeGatherIndices(
- HloInstruction* gather_indices, int64 index_vector_dim) {
+ HloInstruction* start_indices, int64 index_vector_dim) {
// Transpose the non-index-vector dimensions to the front.
TF_ASSIGN_OR_RETURN(
- HloInstruction * transposed_gather_indices,
- TransposeIndexVectorDimToLast(gather_indices, index_vector_dim));
+ HloInstruction * transposed_start_indices,
+ TransposeIndexVectorDimToLast(start_indices, index_vector_dim));
bool indices_are_scalar =
- index_vector_dim == gather_indices->shape().dimensions_size();
+ index_vector_dim == start_indices->shape().dimensions_size();
- // The number of dimensions in gather_indices that are index dimensions.
- const int64 index_dims_in_gather_indices = indices_are_scalar ? 0 : 1;
+ // The number of dimensions in start_indices that are index dimensions.
+ const int64 index_dims_in_start_indices = indices_are_scalar ? 0 : 1;
- // If there is only one index (i.e. gather_indices has rank 1 and this gather
+ // If there is only one index (i.e. start_indices has rank 1 and this gather
// is really just a dynamic slice) add a leading degenerate dimension for
// uniformity. Otherwise create a "collapsed" leading dimension that subsumes
// all of the non-index-vector dimensions.
- const Shape& shape = transposed_gather_indices->shape();
- if (shape.dimensions_size() == index_dims_in_gather_indices) {
- return PrependDegenerateDims(transposed_gather_indices, 1);
+ const Shape& shape = transposed_start_indices->shape();
+ if (shape.dimensions_size() == index_dims_in_start_indices) {
+ return PrependDegenerateDims(transposed_start_indices, 1);
} else {
- // Collapse all but the dimensions (0 or 1) in gather_indices containing the
+ // Collapse all but the dimensions (0 or 1) in start_indices containing the
// index vectors.
return CollapseFirstNDims(
- transposed_gather_indices,
- shape.dimensions_size() - index_dims_in_gather_indices);
+ transposed_start_indices,
+ shape.dimensions_size() - index_dims_in_start_indices);
}
}
// Expands out or contracts away the gather dimensions in the accumulator
// produced by the while loop.
-static StatusOr<HloInstruction*> AdjustGatherDimsInAccumulator(
- const Shape& gather_indices_shape, HloInstruction* accumulator,
+static StatusOr<HloInstruction*> AdjustBatchDimsInAccumulator(
+ const Shape& start_indices_shape, HloInstruction* accumulator,
int64 index_vector_dim) {
- std::vector<int64> output_gather_dim_bounds;
- output_gather_dim_bounds.reserve(gather_indices_shape.dimensions_size());
- for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
+ std::vector<int64> batch_dim_bounds;
+ batch_dim_bounds.reserve(start_indices_shape.dimensions_size());
+ for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
if (i != index_vector_dim) {
- output_gather_dim_bounds.push_back(gather_indices_shape.dimensions(i));
+ batch_dim_bounds.push_back(start_indices_shape.dimensions(i));
}
}
- if (output_gather_dim_bounds.empty()) {
- // If output_gather_dim_bounds is empty we must be lowering a (effectively)
+ if (batch_dim_bounds.empty()) {
+ // If batch_dim_bounds is empty we must be lowering a (effectively)
// dynamic-slice. In that case, there is a leading degenerate gather
// dimension that we added to make this special case play well with the
// general while loop which we need to remove now.
return ElideDegenerateDims(accumulator, {0});
}
- return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds);
+ return ExpandFirstDimIntoNDims(accumulator, batch_dim_bounds);
}
-// Expand an index vector from the gather_indices tensor into a vector that can
+// Expand an index vector from the start_indices tensor into a vector that can
// be used to dynamic-slice out of the gather operand.
static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers,
@@ -121,10 +121,8 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
std::vector<HloInstruction*> expanded_index_components;
for (int i = 0; i < operand_rank; i++) {
- int64 index_vector_dim_index =
- FindIndex(dim_numbers.gather_dims_to_operand_dims(), i);
- if (index_vector_dim_index !=
- dim_numbers.gather_dims_to_operand_dims_size()) {
+ int64 index_vector_dim_index = FindIndex(dim_numbers.start_index_map(), i);
+ if (index_vector_dim_index != dim_numbers.start_index_map_size()) {
TF_ASSIGN_OR_RETURN(
HloInstruction * component_to_concat,
MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
@@ -147,10 +145,10 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers();
CHECK_EQ(incoming_loop_state.size(), 3);
HloInstruction* const operand = incoming_loop_state[0];
- HloInstruction* const gather_indices = incoming_loop_state[1];
+ HloInstruction* const start_indices = incoming_loop_state[1];
HloInstruction* const output_accumulator = incoming_loop_state[2];
- bool has_scalar_indices = gather_indices->shape().dimensions_size() == 1;
+ bool has_scalar_indices = start_indices->shape().dimensions_size() == 1;
CHECK_EQ(has_scalar_indices,
dim_numbers.index_vector_dim() ==
gather.operand(1)->shape().dimensions_size());
@@ -163,24 +161,24 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
HloInstruction* index_vector;
if (has_scalar_indices) {
- // In this case gather_indices has rank 1 and induction_var_as_vector (of
+ // In this case start_indices has rank 1 and induction_var_as_vector (of
// shape {1}) is an index into this rank 1 tensor.
TF_ASSIGN_OR_RETURN(
index_vector,
- MakeDynamicSliceHlo(gather_indices, induction_var_as_vector, {1}));
+ MakeDynamicSliceHlo(start_indices, induction_var_as_vector, {1}));
} else {
- // In this case gather_indices has rank 2 and induction_var_as_vector (of
+ // In this case start_indices has rank 2 and induction_var_as_vector (of
// shape {1}) is an index into just the first dimension of this rank 2
// tensor.
TF_ASSIGN_OR_RETURN(
- HloInstruction * index_into_gather_indices,
+ HloInstruction * index_into_start_indices,
PadVectorWithZeros(induction_var_as_vector,
/*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
- int64 index_vector_size = gather_indices->shape().dimensions(1);
+ int64 index_vector_size = start_indices->shape().dimensions(1);
TF_ASSIGN_OR_RETURN(
HloInstruction * index_vector_2d,
- MakeDynamicSliceHlo(gather_indices, index_into_gather_indices,
+ MakeDynamicSliceHlo(start_indices, index_into_start_indices,
{1, index_vector_size}));
TF_ASSIGN_OR_RETURN(index_vector,
@@ -194,26 +192,26 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice,
MakeDynamicSliceHlo(operand, gathered_slice_start,
- gather.gather_window_bounds()));
+ gather.gather_slice_sizes()));
TF_ASSIGN_OR_RETURN(
- HloInstruction * gathered_slice_with_dims_elided,
+ HloInstruction* const gathered_slice_with_dims_collapsed,
ElideDegenerateDims(gathered_slice,
- AsInt64Slice(dim_numbers.elided_window_dims())));
+ AsInt64Slice(dim_numbers.collapsed_slice_dims())));
TF_ASSIGN_OR_RETURN(
- HloInstruction * gathered_slice_for_update,
- PrependDegenerateDims(gathered_slice_with_dims_elided, 1));
+ HloInstruction* const gathered_slice_for_update,
+ PrependDegenerateDims(gathered_slice_with_dims_collapsed, 1));
TF_ASSIGN_OR_RETURN(
- HloInstruction * index_vector_into_accumulator,
+ HloInstruction* const index_vector_into_accumulator,
PadVectorWithZeros(
induction_var_as_vector, /*zeros_to_prepend=*/0,
/*zeros_to_append=*/
- gathered_slice_with_dims_elided->shape().dimensions_size()));
+ gathered_slice_with_dims_collapsed->shape().dimensions_size()));
TF_ASSIGN_OR_RETURN(
- HloInstruction * updated_accumulator,
+ HloInstruction* const updated_accumulator,
MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update,
index_vector_into_accumulator));
@@ -221,19 +219,19 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
// WhileUtil::MakeCountedLoop functions takes care of the induction variable
// and the while loop exit condition.
return StatusOr<std::vector<HloInstruction*>>{
- {operand, gather_indices, updated_accumulator}};
+ {operand, start_indices, updated_accumulator}};
}
static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue(
HloComputation* computation, PrimitiveType element_type,
- ArraySlice<int64> window_bounds, int64 gather_loop_trip_count,
+ ArraySlice<int64> slice_sizes, int64 gather_loop_trip_count,
const GatherDimensionNumbers& dim_numbers) {
std::vector<int64> accumulator_state_shape_dims;
- accumulator_state_shape_dims.reserve(1 + window_bounds.size());
+ accumulator_state_shape_dims.reserve(1 + slice_sizes.size());
accumulator_state_shape_dims.push_back(gather_loop_trip_count);
- for (int64 i = 0; i < window_bounds.size(); i++) {
- if (!c_binary_search(dim_numbers.elided_window_dims(), i)) {
- accumulator_state_shape_dims.push_back(window_bounds[i]);
+ for (int64 i = 0; i < slice_sizes.size(); i++) {
+ if (!c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
+ accumulator_state_shape_dims.push_back(slice_sizes[i]);
}
}
return BroadcastZeros(computation, element_type,
@@ -241,23 +239,23 @@ static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue(
}
// `accumulator` is almost the tensor the gather operation would have produced,
-// except that it has the dimensions in the wrong order -- the gather dimensions
-// are the major dimensions and the window dimensions are the minor dimensions.
+// except that it has the dimensions in the wrong order -- the batch dimensions
+// are the major dimensions and the offset dimensions are the minor dimensions.
// Fix this up with a transpose.
-static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
- HloInstruction* accumulator, ArraySlice<int64> output_window_dims,
+static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims(
+ HloInstruction* accumulator, ArraySlice<int64> offset_dims,
int64 output_rank) {
std::vector<int64> permutation;
permutation.reserve(output_rank);
- int64 gather_idx_counter = 0;
- int64 window_idx_counter = output_rank - output_window_dims.size();
+ int64 batch_idx_counter = 0;
+ int64 offset_idx_counter = output_rank - offset_dims.size();
for (int64 i = 0; i < output_rank; i++) {
- bool is_window_dim = c_binary_search(output_window_dims, i);
- if (is_window_dim) {
- permutation.push_back(window_idx_counter++);
+ bool is_offset_dim = c_binary_search(offset_dims, i);
+ if (is_offset_dim) {
+ permutation.push_back(offset_idx_counter++);
} else {
- permutation.push_back(gather_idx_counter++);
+ permutation.push_back(batch_idx_counter++);
}
}
@@ -268,11 +266,11 @@ static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
//
// We follow the following steps in sequence:
//
-// 1. We canonicalize the gather_indices tensor such that it has rank
+// 1. We canonicalize the start_indices tensor such that it has rank
// 2 (i.e. is a matrix) where each row is an index vector into the
// operand.
// 2. We iterate over the set of indices in the canonicalized
-// gather_indices tensor using a while loop, accumulating slices
+// start_indices tensor using a while loop, accumulating slices
// of the operand tensor into an accumulator using
// DynamicUpdateSlice.
// 3. The accumulator result from the while loop from (2) is then
@@ -287,11 +285,11 @@ static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
// operand = s32[3,3] parameter(0)
// indices = s32[2,2] parameter(1)
// ROOT gather = s32[2,3,2] gather(operand, indices),
-// output_window_dims={1},
-// elided_window_dims={1},
-// gather_dims_to_operand_dims={1},
+// offset_dims={1},
+// collapsed_slice_dims={1},
+// start_index_map={1},
// index_vector_dim=2,
-// window_bounds={3, 1}
+// slice_sizes={3, 1}
// }
//
// We'd first reshape indices to s32[4,1], where each row is an index
@@ -305,8 +303,8 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
HloComputation* computation = gather_instr->parent();
HloInstruction* operand = gather_instr->mutable_operand(0);
- HloInstruction* gather_indices = gather_instr->mutable_operand(1);
- const Shape& gather_indices_shape = gather_indices->shape();
+ HloInstruction* start_indices = gather_instr->mutable_operand(1);
+ const Shape& start_indices_shape = start_indices->shape();
const Shape& output_shape = gather_instr->shape();
int64 output_rank = output_shape.dimensions_size();
@@ -314,9 +312,9 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
gather_instr->gather_dimension_numbers();
int64 gather_loop_trip_count = 1;
- for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
+ for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
if (i != dim_numbers.index_vector_dim()) {
- gather_loop_trip_count *= gather_indices_shape.dimensions(i);
+ gather_loop_trip_count *= start_indices_shape.dimensions(i);
}
}
@@ -327,24 +325,24 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
gather_instr->ToString().c_str());
}
- TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices,
- CanonicalizeGatherIndices(
- gather_indices, dim_numbers.index_vector_dim()));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * canonical_start_indices,
+ CanonicalizeGatherIndices(start_indices, dim_numbers.index_vector_dim()));
CHECK_EQ(gather_loop_trip_count,
- canonical_gather_indices->shape().dimensions(0));
+ canonical_start_indices->shape().dimensions(0));
TF_ASSIGN_OR_RETURN(
HloInstruction * accumulator_init,
CreateGatherLoopAccumulatorInitValue(
computation, output_shape.element_type(),
- gather_instr->gather_window_bounds(), gather_loop_trip_count,
+ gather_instr->gather_slice_sizes(), gather_loop_trip_count,
gather_instr->gather_dimension_numbers()));
StatusOr<std::vector<HloInstruction*>> gather_loop_result_or_error =
WhileUtil::MakeCountedLoop(
computation, gather_loop_trip_count,
- {operand, canonical_gather_indices, accumulator_init},
+ {operand, canonical_start_indices, accumulator_init},
[&](HloInstruction* indvar,
const std::vector<HloInstruction*>& loop_state) {
return GatherLoopBody(*gather_instr, indvar, loop_state);
@@ -356,13 +354,13 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
HloInstruction* accumulator_result = gather_loop_result.back();
TF_ASSIGN_OR_RETURN(
- HloInstruction * accumulator_with_output_gather_dims_decanonicalized,
- AdjustGatherDimsInAccumulator(gather_indices->shape(), accumulator_result,
- dim_numbers.index_vector_dim()));
+ HloInstruction* const accumulator_with_batch_dims_decanonicalized,
+ AdjustBatchDimsInAccumulator(start_indices->shape(), accumulator_result,
+ dim_numbers.index_vector_dim()));
- return PermuteGatherAndWindowDims(
- accumulator_with_output_gather_dims_decanonicalized,
- AsInt64Slice(dim_numbers.output_window_dims()), output_rank);
+ return PermuteBatchAndOffsetDims(accumulator_with_batch_dims_decanonicalized,
+ AsInt64Slice(dim_numbers.offset_dims()),
+ output_rank);
}
StatusOr<bool> GatherExpander::Run(HloModule* module) {
diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc
index 020ffcd106..141dd4d6f1 100644
--- a/tensorflow/compiler/xla/service/gather_expander_test.cc
+++ b/tensorflow/compiler/xla/service/gather_expander_test.cc
@@ -28,11 +28,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2147483647,5] parameter(1)
ROOT gather = s32[2147483647,3,5] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
@@ -55,11 +55,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 9d24b42401..fa218657fe 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -139,7 +139,7 @@ message HloInstructionProto {
// Gather dimension numbers.
xla.GatherDimensionNumbers gather_dimension_numbers = 33;
- repeated int64 gather_window_bounds = 34;
+ repeated int64 gather_slice_sizes = 34;
// Compute Host.
string channel_name = 41;
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 51353eea6e..36d6a2eed6 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -555,43 +555,39 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
return Status::OK();
}
-// Returns an ShapeUtil::IndexIterationSpace that iterates over the output
-// gather dimensions while keeping the rest of the output dimensions clamped to
-// 0.
-ShapeUtil::IndexIterationSpace IterationSpaceForOutputGatherIndices(
+// Returns an ShapeUtil::IndexIterationSpace that iterates over the output batch
+// dimensions while keeping the rest of the output dimensions clamped to 0.
+ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
int64 output_rank = output_shape.dimensions_size();
std::vector<int64> index_base(output_rank, 0);
std::vector<int64> index_count;
index_count.reserve(output_rank);
for (int64 i = 0; i < output_rank; i++) {
- bool is_output_gather_dim =
- !c_binary_search(dim_numbers.output_window_dims(), i);
- index_count.push_back(is_output_gather_dim ? output_shape.dimensions(i)
- : 1);
+ bool is_output_batch_dim = !c_binary_search(dim_numbers.offset_dims(), i);
+ index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
}
return {std::move(index_base), std::move(index_count),
std::vector<int64>(output_rank, 1)};
}
-// Return an ShapeUtil::IndexIterationSpace that iterates over the output window
+// Return an ShapeUtil::IndexIterationSpace that iterates over the output slice
// dimensions while keeping the rest of the output dimensions clamped to 0.
-ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices(
- int64 output_rank, ArraySlice<int64> window_bounds,
+ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
+ int64 output_rank, ArraySlice<int64> slice_sizes,
const GatherDimensionNumbers& dim_numbers) {
std::vector<int64> index_base(output_rank, 0);
std::vector<int64> index_count(output_rank, 1);
- int64 window_bounds_idx = 0;
+ int64 slice_sizes_idx = 0;
for (int64 i = 0; i < output_rank; i++) {
- bool is_output_window_dim =
- c_binary_search(dim_numbers.output_window_dims(), i);
+ bool is_output_window_dim = c_binary_search(dim_numbers.offset_dims(), i);
if (is_output_window_dim) {
- while (c_binary_search(dim_numbers.elided_window_dims(),
- window_bounds_idx)) {
- window_bounds_idx++;
+ while (c_binary_search(dim_numbers.collapsed_slice_dims(),
+ slice_sizes_idx)) {
+ slice_sizes_idx++;
}
- index_count[i] = window_bounds[window_bounds_idx++];
+ index_count[i] = slice_sizes[slice_sizes_idx++];
}
}
@@ -599,30 +595,30 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices(
std::vector<int64>(output_rank, 1)};
}
-// This functor computes the contribution of gather_indices to an input index
+// This functor computes the contribution of start_indices to an input index
// corresponding to an output index. That is, given an output index I, it picks
-// out the gather output indices in I and uses them to look up a gather index,
-// G, from the gather indices tensor, and expands G into the input space
-// according to gather_dims_to_operand_dims.
-class OutputGatherIndexToInputIndex {
+// out the batch indices in I and uses them to look up a starting index, G, from
+// the start indices tensor, and expands G into the input space according to
+// start_index_map.
+class OutputBatchIndexToInputIndex {
public:
// The constructor does some setup work that is amortized across all
// iterations.
- explicit OutputGatherIndexToInputIndex(
+ explicit OutputBatchIndexToInputIndex(
const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
- const Shape& output_shape, const Literal* gather_indices)
- : dim_numbers_(*dim_numbers), gather_indices_(*gather_indices) {
+ const Shape& output_shape, const Literal* start_indices)
+ : dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
- output_dim_is_gather_dims_.push_back(
- !c_binary_search(dim_numbers_.output_window_dims(), i));
+ output_dim_is_batch_dims_.push_back(
+ !c_binary_search(dim_numbers_.offset_dims(), i));
}
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
int64 index_of_input_dim_in_index_vector =
- std::distance(dim_numbers_.gather_dims_to_operand_dims().begin(),
- c_find(dim_numbers_.gather_dims_to_operand_dims(), i));
+ std::distance(dim_numbers_.start_index_map().begin(),
+ c_find(dim_numbers_.start_index_map(), i));
if (index_of_input_dim_in_index_vector ==
- dim_numbers_.gather_dims_to_operand_dims_size()) {
+ dim_numbers_.start_index_map_size()) {
input_dim_value_to_index_vector_.push_back(-1);
} else {
input_dim_value_to_index_vector_.push_back(
@@ -630,14 +626,14 @@ class OutputGatherIndexToInputIndex {
}
}
- index_vector_index_.resize(gather_indices_.shape().dimensions_size());
+ index_vector_index_.resize(start_indices_.shape().dimensions_size());
input_index_.resize(input_shape.dimensions_size());
int64 index_vector_size =
- gather_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
+ start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
index_vector_.resize(index_vector_size);
}
- // Returns the contribution of gather_indices to the input index corresponding
+ // Returns the contribution of start_indices to the input index corresponding
// to output_index. See gather_inner_loop_body.
//
// This is conceptually a stateless transformation from output_index to the
@@ -659,7 +655,7 @@ class OutputGatherIndexToInputIndex {
}
private:
- // Propagates the gather index dimensions from the output index into
+ // Propagates the batch dimensions from the output index into
// index_vector_index_ by mutating index_vector_index_ in place. Does not
// update the dim_numbers.index_vector_dim() dimension -- that's the dimension
// we iterate over in FetchIndexVector.
@@ -667,7 +663,7 @@ class OutputGatherIndexToInputIndex {
ArraySlice<int64> output_index) {
int64 index_vector_index_i = 0;
for (int64 i = 0, e = output_index.size(); i < e; i++) {
- if (!output_dim_is_gather_dims_[i]) {
+ if (!output_dim_is_batch_dims_[i]) {
continue;
}
@@ -679,14 +675,14 @@ class OutputGatherIndexToInputIndex {
}
}
- // Populates index_vector_ by iterating over gather_indices_ according to
+ // Populates index_vector_ by iterating over start_indices_ according to
// index_vector_index_.
Status FetchIndexVector() {
int64 index_vector_dim = dim_numbers_.index_vector_dim();
for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
index_vector_index_[index_vector_dim] = i;
- TF_ASSIGN_OR_RETURN(index_vector_[i], gather_indices_.GetIntegralAsS64(
- index_vector_index_));
+ TF_ASSIGN_OR_RETURN(index_vector_[i],
+ start_indices_.GetIntegralAsS64(index_vector_index_));
}
return Status::OK();
}
@@ -708,15 +704,15 @@ class OutputGatherIndexToInputIndex {
// PropagateIndexVectorToInputIndex.
std::vector<int64> input_dim_value_to_index_vector_;
- // output_dim_is_gather_dims_[i] is true iff the output index i is a gather
+ // output_dim_is_batch_dims_[i] is true iff the output index i is a gather
// dimension.
- std::vector<bool> output_dim_is_gather_dims_;
+ std::vector<bool> output_dim_is_batch_dims_;
- // The buffer into which we construct an index into gather_indices_ to fetch
+ // The buffer into which we construct an index into start_indices_ to fetch
// the index vector.
std::vector<int64> index_vector_index_;
- // The index vector fetched from gather_indices_.
+ // The index vector fetched from start_indices_.
std::vector<int64> index_vector_;
// The result computed by this functor. operator() returns an ArraySlice into
@@ -724,24 +720,23 @@ class OutputGatherIndexToInputIndex {
std::vector<int64> input_index_;
const GatherDimensionNumbers& dim_numbers_;
- const Literal& gather_indices_;
+ const Literal& start_indices_;
};
-// This functor computes the contribution of the window indices in an output
+// This functor computes the contribution of the offset indices in an output
// index to an input index. That is, given an output index I it picks out the
-// output window indices in I and expands it into a window index into the input
-// shape.
-class OutputWindowIndexToInputIndex {
+// output offset indices in I and expands it into an index into the input shape.
+class OutputOffsetIndexToInputIndex {
public:
// The constructor does some setup work that is amortized across all
// iterations.
- explicit OutputWindowIndexToInputIndex(
+ explicit OutputOffsetIndexToInputIndex(
const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
const Shape& output_shape) {
std::vector<int64> window_index_to_output_index;
int64 output_index_count = 0;
for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.output_window_dims(), i)) {
+ if (c_binary_search(dim_numbers.offset_dims(), i)) {
window_index_to_output_index.push_back(output_index_count++);
} else {
output_index_count++;
@@ -750,7 +745,7 @@ class OutputWindowIndexToInputIndex {
int64 window_dim_count = 0;
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
+ if (c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
input_dim_value_to_output_index_.push_back(-1);
} else {
input_dim_value_to_output_index_.push_back(
@@ -808,20 +803,20 @@ class OutputWindowIndexToInputIndex {
// Rehapes the gather indices input to have a trailing degenerate `1` dimension
// if necessary. Hands over the ownership of the newly created literal (if
-// there is one) to `reshaped_gather_indices`.
+// there is one) to `reshaped_start_indices`.
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
- int64 index_vector_dim, const Literal& gather_indices,
- std::unique_ptr<Literal>* reshaped_gather_indices) {
- if (gather_indices.shape().dimensions_size() != index_vector_dim) {
- return std::cref(gather_indices);
+ int64 index_vector_dim, const Literal& start_indices,
+ std::unique_ptr<Literal>* reshaped_start_indices) {
+ if (start_indices.shape().dimensions_size() != index_vector_dim) {
+ return std::cref(start_indices);
}
- std::vector<int64> new_shape(gather_indices.shape().dimensions().begin(),
- gather_indices.shape().dimensions().end());
+ std::vector<int64> new_shape(start_indices.shape().dimensions().begin(),
+ start_indices.shape().dimensions().end());
new_shape.push_back(1);
- TF_ASSIGN_OR_RETURN(*reshaped_gather_indices,
- gather_indices.Reshape(new_shape));
- return std::cref(**reshaped_gather_indices);
+ TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
+ start_indices.Reshape(new_shape));
+ return std::cref(**reshaped_start_indices);
}
Status HloEvaluator::HandleGather(HloInstruction* gather) {
@@ -830,34 +825,33 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
const GatherDimensionNumbers& dim_numbers =
gather->gather_dimension_numbers();
const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
- std::unique_ptr<Literal> reshaped_gather_indices;
+ std::unique_ptr<Literal> reshaped_start_indices;
TF_ASSIGN_OR_RETURN(
- const Literal& gather_indices,
+ const Literal& start_indices,
ReshapedGatherIndices(dim_numbers.index_vector_dim(),
GetEvaluatedLiteralFor(gather->operand(1)),
- &reshaped_gather_indices));
+ &reshaped_start_indices));
// We iterate over the gather dimensions in the output shape in an outer loop
// nest, and iterate over the window dimensions in the output shape in an
// inner loop nest.
- ShapeUtil::IndexIterationSpace gather_indices_iteration_space =
- IterationSpaceForOutputGatherIndices(shape, dim_numbers);
- ShapeUtil::IndexIterationSpace window_indices_iteration_space =
- IterationSpaceForOutputWindowIndices(
- shape.dimensions_size(), gather->gather_window_bounds(), dim_numbers);
+ ShapeUtil::IndexIterationSpace start_indices_iteration_space =
+ IterationSpaceForOutputBatchIndices(shape, dim_numbers);
+ ShapeUtil::IndexIterationSpace offset_indices_iteration_space =
+ IterationSpaceForOutputOffsetIndices(
+ shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers);
// Scratch buffers that hold an index in the output shape and the
// corresponding index in the input shape.
std::vector<int64> input_index(operand.shape().dimensions_size());
std::vector<int64> output_index(gather->shape().dimensions_size());
- std::vector<int64> input_gather_index_clamped(
- operand.shape().dimensions_size());
+ std::vector<int64> input_index_clamped(operand.shape().dimensions_size());
- OutputGatherIndexToInputIndex output_gather_index_to_input_index(
+ OutputBatchIndexToInputIndex output_batch_index_to_input_index(
&gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
- /*output_shape=*/shape, &gather_indices);
- OutputWindowIndexToInputIndex output_window_index_to_input_index(
+ /*output_shape=*/shape, &start_indices);
+ OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
/*output_shape=*/shape);
@@ -869,29 +863,29 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
ArraySlice<int64> input_window_index,
- output_window_index_to_input_index(output_window_index));
+ output_offset_index_to_input_index(output_window_index));
for (int i = 0, e = output_index.size(); i < e; i++) {
output_index[i] = output_gather_index[i] + output_window_index[i];
DCHECK_LT(output_index[i], shape.dimensions(i));
}
for (int i = 0, e = input_gather_index.size(); i < e; i++) {
int64 output_dim =
- output_window_index_to_input_index.input_dim_value_to_output_index(i);
+ output_offset_index_to_input_index.input_dim_value_to_output_index(i);
// If 'output_dim' is -1, it means 'i' is an elided window dim. This means
// we set the iteration index to 0, so for the purpose of the following
// calculations we can consider the output dimension size to be 1.
int64 output_dim_size =
output_dim == -1 ? 1 : shape.dimensions(output_dim);
// Clamp the gather index so that the gather region fits in the operand.
- // input_gather_index_clamped[i] = clamp(input_gather_index[i], 0,
+ // input_index_clamped[i] = clamp(input_gather_index[i], 0,
// operand_shape.dimensions(i) -
// output_dim_size);
- input_gather_index_clamped[i] =
+ input_index_clamped[i] =
std::min(operand_shape.dimensions(i) - output_dim_size,
std::max(0LL, input_gather_index[i]));
}
for (int i = 0, e = input_index.size(); i < e; i++) {
- input_index[i] = input_gather_index_clamped[i] + input_window_index[i];
+ input_index[i] = input_index_clamped[i] + input_window_index[i];
DCHECK_GE(input_index[i], 0);
DCHECK_LT(input_index[i], operand_shape.dimensions(i));
}
@@ -902,18 +896,17 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
auto gather_outer_loop_body =
[&](ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
- TF_ASSIGN_OR_RETURN(
- ArraySlice<int64> input_gather_index,
- output_gather_index_to_input_index(output_gather_index));
+ TF_ASSIGN_OR_RETURN(ArraySlice<int64> input_gather_index,
+ output_batch_index_to_input_index(output_gather_index));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
- shape, window_indices_iteration_space,
+ shape, offset_indices_iteration_space,
std::bind(gather_inner_loop_body, std::placeholders::_1,
input_gather_index, output_gather_index)));
return true;
};
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
- shape, gather_indices_iteration_space, gather_outer_loop_body));
+ shape, start_indices_iteration_space, gather_outer_loop_body));
evaluated_[gather] = std::move(result);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 1cef3549e0..1394be68e4 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -1826,21 +1826,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 3}
+ slice_sizes={1, 3}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
*LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
@@ -1851,21 +1850,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
*LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
@@ -1876,22 +1874,22 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
EXPECT_TRUE(LiteralTestUtil::Equal(
*LiteralUtil::CreateR3<int32>(
{{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
@@ -1902,11 +1900,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
ParseAndVerifyModule(hlo_text);
@@ -1914,11 +1912,11 @@ ENTRY main {
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest,
@@ -1930,11 +1928,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
ParseAndVerifyModule(hlo_text);
@@ -1942,11 +1940,11 @@ ENTRY main {
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
@@ -1957,21 +1955,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{5}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
@@ -1982,21 +1979,21 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
@@ -2007,20 +2004,19 @@ ENTRY main {
operand = s32[3,0] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,0] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 0}
+ slice_sizes={1, 0}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{}, {}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
@@ -2031,21 +2027,21 @@ ENTRY main {
operand = s32[3] parameter(0)
indices = s32[2,2,1] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1}
+ slice_sizes={1}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 4aaef1941b..57e75cf931 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -392,13 +392,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< "Gather instruction should have GatherDimensionNumbers set.";
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers =
MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers());
- std::vector<int64> gather_window_bounds;
- for (int64 bound : proto.gather_window_bounds()) {
- gather_window_bounds.push_back(bound);
+ std::vector<int64> gather_slice_sizes;
+ for (int64 bound : proto.gather_slice_sizes()) {
+ gather_slice_sizes.push_back(bound);
}
- instruction =
- CreateGather(proto.shape(), operands(0), operands(1),
- *gather_dimension_numbers, gather_window_bounds);
+ instruction = CreateGather(proto.shape(), operands(0), operands(1),
+ *gather_dimension_numbers, gather_slice_sizes);
break;
}
case HloOpcode::kScatter: {
@@ -1078,11 +1077,11 @@ bool HloInstruction::HasSideEffect() const {
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
- const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
+ const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
- return MakeUnique<HloGatherInstruction>(shape, operand, gather_indices,
- gather_dim_numbers, window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ return MakeUnique<HloGatherInstruction>(shape, operand, start_indices,
+ gather_dim_numbers, slice_sizes);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
@@ -3226,9 +3225,8 @@ const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
}
-tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_window_bounds()
- const {
- return Cast<HloGatherInstruction>(this)->gather_window_bounds();
+tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_slice_sizes() const {
+ return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
}
const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index b3eee90099..8d8f149ee3 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -667,9 +667,9 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateGather(
const Shape& shape, HloInstruction* operand,
- HloInstruction* gather_indices,
+ HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
static std::unique_ptr<HloInstruction> CreateScatter(
const Shape& shape, HloInstruction* operand,
@@ -1489,8 +1489,8 @@ class HloInstruction {
// Delegates to HloGatherInstruction::gather_dimension_numbers.
const GatherDimensionNumbers& gather_dimension_numbers() const;
- // Delegates to HloGatherInstruction::gather_window_bounds.
- tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const;
+ // Delegates to HloGatherInstruction::gather_slice_sizes.
+ tensorflow::gtl::ArraySlice<int64> gather_slice_sizes() const;
// Delegates to HloScatterInstruction::scatter_dimension_numbers().
const ScatterDimensionNumbers& scatter_dimension_numbers() const;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 8a694dde80..504b13043f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1355,7 +1355,7 @@ TEST_F(HloInstructionTest, Stringification) {
TEST_F(HloInstructionTest, StringifyGather_0) {
Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
- Shape gather_indices_tensor_shape =
+ Shape start_indices_tensor_shape =
ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
Shape gather_result_shape =
ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
@@ -1363,19 +1363,18 @@ TEST_F(HloInstructionTest, StringifyGather_0) {
HloComputation::Builder builder("Gather");
HloInstruction* input = builder.AddInstruction(
HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
- HloInstruction* gather_indices =
+ HloInstruction* start_indices =
builder.AddInstruction(HloInstruction::CreateParameter(
- 1, gather_indices_tensor_shape, "gather_indices"));
-
- HloInstruction* gather_instruction =
- builder.AddInstruction(HloInstruction::CreateGather(
- gather_result_shape, input, gather_indices,
- HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
- /*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ 1, start_indices_tensor_shape, "start_indices"));
+
+ HloInstruction* gather_instruction = builder.AddInstruction(
+ HloInstruction::CreateGather(gather_result_shape, input, start_indices,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -1383,15 +1382,15 @@ TEST_F(HloInstructionTest, StringifyGather_0) {
EXPECT_EQ(gather_instruction->ToString(),
"%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
"gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
- "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), "
- "output_window_dims={4,5,6,7,8}, elided_window_dims={}, "
- "gather_dims_to_operand_dims={0,1,2,3,4}, "
- "index_vector_dim=4, window_bounds={30,29,28,27,26}");
+ "s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), "
+ "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
+ "start_index_map={0,1,2,3,4}, "
+ "index_vector_dim=4, slice_sizes={30,29,28,27,26}");
}
TEST_F(HloInstructionTest, StringifyGather_1) {
Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
- Shape gather_indices_tensor_shape =
+ Shape start_indices_tensor_shape =
ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
Shape gather_result_shape =
ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
@@ -1399,19 +1398,18 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
HloComputation::Builder builder("Gather");
HloInstruction* input = builder.AddInstruction(
HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
- HloInstruction* gather_indices =
+ HloInstruction* start_indices =
builder.AddInstruction(HloInstruction::CreateParameter(
- 1, gather_indices_tensor_shape, "gather_indices"));
-
- HloInstruction* gather_instruction =
- builder.AddInstruction(HloInstruction::CreateGather(
- gather_result_shape, input, gather_indices,
- HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
- /*index_vector_dim=*/2),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ 1, start_indices_tensor_shape, "start_indices"));
+
+ HloInstruction* gather_instruction = builder.AddInstruction(
+ HloInstruction::CreateGather(gather_result_shape, input, start_indices,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/2),
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -1419,10 +1417,10 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
EXPECT_EQ(gather_instruction->ToString(),
"%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
"gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
- "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), "
- "output_window_dims={4,5,6,7,8}, elided_window_dims={}, "
- "gather_dims_to_operand_dims={0,1,2,3,4}, "
- "index_vector_dim=2, window_bounds={30,29,28,27,26}");
+ "s64[10,9,5,7,6]{4,3,2,1,0} %start_indices), "
+ "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
+ "start_index_map={0,1,2,3,4}, "
+ "index_vector_dim=2, slice_sizes={30,29,28,27,26}");
}
TEST_F(HloInstructionTest, StringifyScatter) {
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 233cdda7b0..4fdf4360e6 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1965,51 +1965,50 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
}
HloGatherInstruction::HloGatherInstruction(
- const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
+ const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds)
+ tensorflow::gtl::ArraySlice<int64> slice_sizes)
: HloInstruction(HloOpcode::kGather, shape) {
AppendOperand(operand);
- AppendOperand(gather_indices);
+ AppendOperand(start_indices);
gather_dimension_numbers_ =
MakeUnique<GatherDimensionNumbers>(gather_dim_numbers);
- c_copy(window_bounds, std::back_inserter(gather_window_bounds_));
+ c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
}
string HloGatherInstruction::GatherDimensionNumbersToString() const {
CHECK(gather_dimension_numbers_ != nullptr);
- string output_window_dims =
- StrCat("output_window_dims={",
- Join(gather_dimension_numbers_->output_window_dims(), ","), "}");
- string elided_window_dims =
- StrCat("elided_window_dims={",
- Join(gather_dimension_numbers_->elided_window_dims(), ","), "}");
- string gather_dims_to_operand_dims = StrCat(
- "gather_dims_to_operand_dims={",
- Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}");
+ string offset_dims =
+ StrCat("offset_dims={",
+ Join(gather_dimension_numbers_->offset_dims(), ","), "}");
+ string collapsed_slice_dims =
+ StrCat("collapsed_slice_dims={",
+ Join(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}");
+ string start_index_map =
+ StrCat("start_index_map={",
+ Join(gather_dimension_numbers_->start_index_map(), ","), "}");
string index_vector_dim = StrCat(
"index_vector_dim=", gather_dimension_numbers_->index_vector_dim());
return Join<std::initializer_list<string>>(
- {output_window_dims, elided_window_dims, gather_dims_to_operand_dims,
- index_vector_dim},
+ {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
", ");
}
/* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> output_window_dims,
- tensorflow::gtl::ArraySlice<int64> elided_window_dims,
- tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
+ tensorflow::gtl::ArraySlice<int64> offset_dims,
+ tensorflow::gtl::ArraySlice<int64> collapsed_slice_dims,
+ tensorflow::gtl::ArraySlice<int64> start_index_map,
int64 index_vector_dim) {
GatherDimensionNumbers gather_dim_numbers;
- for (int64 output_window_dim : output_window_dims) {
- gather_dim_numbers.add_output_window_dims(output_window_dim);
+ for (int64 output_window_dim : offset_dims) {
+ gather_dim_numbers.add_offset_dims(output_window_dim);
}
- for (int64 elided_window_dim : elided_window_dims) {
- gather_dim_numbers.add_elided_window_dims(elided_window_dim);
+ for (int64 elided_window_dim : collapsed_slice_dims) {
+ gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim);
}
- for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) {
- gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim);
+ for (int64 gather_dim_to_input_dim : start_index_map) {
+ gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
}
gather_dim_numbers.set_index_vector_dim(index_vector_dim);
@@ -2019,8 +2018,8 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const {
HloInstructionProto HloGatherInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
*proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
- for (int64 bound : gather_window_bounds()) {
- proto.add_gather_window_bounds(bound);
+ for (int64 bound : gather_slice_sizes()) {
+ proto.add_gather_slice_sizes(bound);
}
return proto;
}
@@ -2028,7 +2027,7 @@ HloInstructionProto HloGatherInstruction::ToProto() const {
std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
return {GatherDimensionNumbersToString(),
- StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")};
+ StrCat("slice_sizes={", Join(gather_slice_sizes(), ","), "}")};
}
bool HloGatherInstruction::IdenticalSlowPath(
@@ -2039,7 +2038,7 @@ bool HloGatherInstruction::IdenticalSlowPath(
return protobuf_util::ProtobufEquals(
gather_dimension_numbers(),
casted_other.gather_dimension_numbers()) &&
- gather_window_bounds() == casted_other.gather_window_bounds();
+ gather_slice_sizes() == casted_other.gather_slice_sizes();
}
std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
@@ -2049,7 +2048,7 @@ std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
CHECK_EQ(new_operands.size(), 2);
return MakeUnique<HloGatherInstruction>(
shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
- gather_window_bounds());
+ gather_slice_sizes());
}
HloScatterInstruction::HloScatterInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 546949bc72..803dbeabeb 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1212,15 +1212,15 @@ class HloGatherInstruction : public HloInstruction {
public:
explicit HloGatherInstruction(
const Shape& shape, HloInstruction* operand,
- HloInstruction* gather_indices,
+ HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
const GatherDimensionNumbers& gather_dimension_numbers() const {
CHECK(gather_dimension_numbers_ != nullptr);
return *gather_dimension_numbers_;
}
- tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
- return gather_window_bounds_;
+ tensorflow::gtl::ArraySlice<int64> gather_slice_sizes() const {
+ return gather_slice_sizes_;
}
// Returns the dump string of the gather dimension numbers.
string GatherDimensionNumbersToString() const;
@@ -1229,9 +1229,9 @@ class HloGatherInstruction : public HloInstruction {
// Creates an instance of GatherDimensionNumbers.
static GatherDimensionNumbers MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> output_window_dims,
- tensorflow::gtl::ArraySlice<int64> elided_window_dims,
- tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
+ tensorflow::gtl::ArraySlice<int64> offset_dims,
+ tensorflow::gtl::ArraySlice<int64> collapsed_slice_dims,
+ tensorflow::gtl::ArraySlice<int64> start_index_map,
int64 index_vector_dim);
private:
@@ -1247,7 +1247,7 @@ class HloGatherInstruction : public HloInstruction {
HloCloneContext* context) const override;
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
- std::vector<int64> gather_window_bounds_;
+ std::vector<int64> gather_slice_sizes_;
};
class HloScatterInstruction : public HloInstruction {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index eb48337cd7..ab57a8b07f 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1233,22 +1233,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kGather: {
- optional<std::vector<tensorflow::int64>> output_window_dims;
- attrs["output_window_dims"] = {
- /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims};
- optional<std::vector<tensorflow::int64>> elided_window_dims;
- attrs["elided_window_dims"] = {
- /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims};
- optional<std::vector<tensorflow::int64>> gather_dims_to_operand_dims;
- attrs["gather_dims_to_operand_dims"] = {/*required=*/true,
- AttrTy::kBracedInt64List,
- &gather_dims_to_operand_dims};
+ optional<std::vector<tensorflow::int64>> offset_dims;
+ attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &offset_dims};
+ optional<std::vector<tensorflow::int64>> collapsed_slice_dims;
+ attrs["collapsed_slice_dims"] = {
+ /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
+ optional<std::vector<tensorflow::int64>> start_index_map;
+ attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &start_index_map};
optional<tensorflow::int64> index_vector_dim;
attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
&index_vector_dim};
- optional<std::vector<tensorflow::int64>> window_bounds;
- attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List,
- &window_bounds};
+ optional<std::vector<tensorflow::int64>> slice_sizes;
+ attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &slice_sizes};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
@@ -1257,14 +1256,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
GatherDimensionNumbers dim_numbers =
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/*output_window_dims,
- /*elided_window_dims=*/*elided_window_dims,
- /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims,
+ /*offset_dims=*/*offset_dims,
+ /*collapsed_slice_dims=*/*collapsed_slice_dims,
+ /*start_index_map=*/*start_index_map,
/*index_vector_dim=*/*index_vector_dim);
instruction = builder->AddInstruction(HloInstruction::CreateGather(
- shape, /*operand=*/operands[0], /*gather_indices=*/operands[1],
- dim_numbers, *window_bounds));
+ shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
+ dim_numbers, *slice_sizes));
break;
}
case HloOpcode::kScatter: {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 6fa3c63d83..0d7919346b 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -752,10 +752,10 @@ ENTRY %sparse_f32_r1 () -> f32[9] {
"gather",
R"(HloModule StringifyGather
-ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
+ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
%input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
- %gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
- ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26}
+ %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+ ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
}
)"
@@ -1030,8 +1030,8 @@ R"(HloModule gather
ENTRY Gather {
input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
- gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
- ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26}
+ start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+ ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
}
)"
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 949a4d1110..ac1a663633 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -572,7 +572,7 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) {
gather,
ShapeInference::InferGatherShape(
gather->operand(0)->shape(), gather->operand(1)->shape(),
- gather->gather_dimension_numbers(), gather->gather_window_bounds()));
+ gather->gather_dimension_numbers(), gather->gather_slice_sizes()));
}
Status ShapeVerifier::HandleScatter(HloInstruction* scatter) {
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 3531b7223f..8d17c03afc 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -153,7 +153,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
TF_ASSIGN_OR_RETURN(
computed_array,
ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(),
- instr->gather_window_bounds(),
+ instr->gather_slice_sizes(),
FindOrDie(cache_, instr->operand(0)),
FindOrDie(cache_, instr->operand(1))));
} else if (instr->opcode() == HloOpcode::kReshape) {
@@ -251,24 +251,23 @@ StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
const Shape& shape, const GatherDimensionNumbers& dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes, Array* source,
Array* indices) {
if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) {
VLOG(3) << "ComputeArrayForGather: indices are not scalar";
return nullptr;
}
- CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1);
+ CHECK_EQ(dim_numbers.start_index_map_size(), 1);
- // We can also handle dim_numbers.elided_window_dims_size() == 0 here, should
- // it become relevant.
+ // We can also handle dim_numbers.collapsed_slice_dims_size() == 0 here,
+ // should it become relevant.
- if (dim_numbers.elided_window_dims_size() != 1 ||
- dim_numbers.elided_window_dims(0) !=
- dim_numbers.gather_dims_to_operand_dims(0)) {
+ if (dim_numbers.collapsed_slice_dims_size() != 1 ||
+ dim_numbers.collapsed_slice_dims(0) != dim_numbers.start_index_map(0)) {
VLOG(3) << "ComputeArrayForGather: gather operations must elide "
- "gather_dims_to_operand_dims[0] and "
- "gather_dims_to_operand_dims[0] only";
+ "start_index_map[0] and "
+ "start_index_map[0] only";
return nullptr;
}
@@ -277,21 +276,21 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
// arrays from an array of size [7,4,6]. We check that condition down below:
for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) {
- if (i != dim_numbers.elided_window_dims(0) &&
- source->shape().dimensions(i) != window_bounds[i]) {
- VLOG(3) << "ComputeArrayForGather: window_bounds[" << i
+ if (i != dim_numbers.collapsed_slice_dims(0) &&
+ source->shape().dimensions(i) != slice_sizes[i]) {
+ VLOG(3) << "ComputeArrayForGather: slice_sizes[" << i
<< "] != source->shape().dimensions(" << i << ") -- "
- << source->shape().dimensions(i) << " vs. " << window_bounds[i]
- << " with dim_numbers.elided_window_dims(0) = "
- << dim_numbers.elided_window_dims(0);
+ << source->shape().dimensions(i) << " vs. " << slice_sizes[i]
+ << " with dim_numbers.collapsed_slice_dims(0) = "
+ << dim_numbers.collapsed_slice_dims(0);
return nullptr;
}
}
- int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0);
+ int64 source_dim = dim_numbers.start_index_map(0);
std::vector<int64> output_dims;
for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
- if (!c_binary_search(dim_numbers.output_window_dims(), i)) {
+ if (!c_binary_search(dim_numbers.offset_dims(), i)) {
output_dims.push_back(i);
}
}
@@ -735,11 +734,11 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
// operand = s32[3,5,2] constant({...})
// indices = s32[7] parameter(0)
// gather = s32[3,2,7] gather(operand, indices),
- // output_window_dims={0,1},
- // elided_window_dims={1},
- // gather_dims_to_operand_dims={1},
+ // offset_dims={0,1},
+ // collapsed_slice_dims={1},
+ // start_index_map={1},
// index_vector_dim=1,
- // window_bounds={3,1,2}
+ // slice_sizes={3,1,2}
// reshape = s32[6,7] reshape(gather)
//
// In this case the gather maps to:
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index e923dc39f7..675eb31d26 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -265,7 +265,7 @@ class IndexedArrayAnalysis {
StatusOr<Array*> ComputeArrayForGather(
const Shape& shape, const GatherDimensionNumbers& dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes, Array* source,
Array* indices);
StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index 5f4b42799b..97052edf7d 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -82,11 +82,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
}
)";
@@ -102,11 +102,11 @@ ENTRY main {
operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}})
indices = s32[5] parameter(0)
ROOT gather = s32[5,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
}
)";
@@ -122,11 +122,11 @@ ENTRY main {
operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}})
indices = s32[5,2] parameter(0)
ROOT gather = s32[5] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
@@ -141,11 +141,11 @@ ENTRY main {
operand = s32[3,3,1] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,2},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0,2},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3,1}
+ slice_sizes={1,3,1}
}
)";
@@ -160,11 +160,11 @@ ENTRY main {
operand = s32[3,3,1] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,2,3] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={2},
- gather_dims_to_operand_dims={0},
+ offset_dims={1,2},
+ collapsed_slice_dims={2},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={2,3,1}
+ slice_sizes={2,3,1}
}
)";
@@ -179,11 +179,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,2}
+ slice_sizes={1,2}
}
)";
@@ -199,17 +199,17 @@ ENTRY main {
indices_a = s32[5] parameter(0)
indices_b = s32[2] parameter(1)
gather_a = s32[5,3] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
ROOT gather_b = s32[2,3] gather(gather_a, indices_b),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
}
)";
@@ -228,17 +228,17 @@ ENTRY main {
indices_a = s32[5,7] parameter(1)
indices_b = s32[2] parameter(2)
gather_a = s32[5,3,7] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b),
- output_window_dims={0,1},
- elided_window_dims={2},
- gather_dims_to_operand_dims={2},
+ offset_dims={0,1},
+ collapsed_slice_dims={2},
+ start_index_map={2},
index_vector_dim=1,
- window_bounds={5,3,1}
+ slice_sizes={5,3,1}
}
)";
@@ -256,17 +256,17 @@ ENTRY main {
indices_a = s32[2] parameter(1)
indices_b = s32[5,7] parameter(2)
gather_a = s32[2,6] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,6}
+ slice_sizes={1,6}
}
)";
@@ -284,17 +284,17 @@ ENTRY main {
indices_a = s32[5,7] parameter(1)
indices_b = s32[4,8] parameter(2)
gather_a = s32[5,3,7] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b),
- output_window_dims={1,2},
- elided_window_dims={2},
- gather_dims_to_operand_dims={2},
+ offset_dims={1,2},
+ collapsed_slice_dims={2},
+ start_index_map={2},
index_vector_dim=2,
- window_bounds={5,3,1}
+ slice_sizes={5,3,1}
}
)";
@@ -312,11 +312,11 @@ ENTRY main {
operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
indices = s32[5] parameter(0)
gather = s32[5,4] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT reshape = s32[5,2,2] reshape(gather)
}
)";
@@ -333,11 +333,11 @@ ENTRY main {
operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
indices = s32[5,7] parameter(0)
gather = s32[5,4,7] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT reshape = s32[5,2,2,7] reshape(gather)
}
)";
@@ -358,11 +358,11 @@ ENTRY main {
{{1,2,3,4,5,6},{1,2,3,4,5,6}}})
indices = s32[5,7] parameter(0)
gather = s32[5,2,6,7] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1,2},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,2,6}
+ slice_sizes={1,2,6}
ROOT reshape = s32[5,3,4,7] reshape(gather)
}
)";
@@ -381,11 +381,11 @@ ENTRY main {
{1,2,3,4,5,6},{1,2,3,4,5,6}})
indices = s32[1] parameter(0)
gather = s32[1,6] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT reshape = s32[1,1,6] reshape(gather)
}
)";
@@ -408,14 +408,14 @@ ENTRY main {
operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } })
i.0 = s64[1,3]{1,0} parameter(0)
- g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), output_window_dims={2},
- elided_window_dims={0}, gather_dims_to_operand_dims={0},
- index_vector_dim=2, window_bounds={1,3}
+ g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2},
+ collapsed_slice_dims={0}, start_index_map={0},
+ index_vector_dim=2, slice_sizes={1,3}
i.1 = s64[1] parameter(1)
- g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), output_window_dims={0,2},
- elided_window_dims={1}, gather_dims_to_operand_dims={1},
- index_vector_dim=1, window_bounds={1,1,3}
+ g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), offset_dims={0,2},
+ collapsed_slice_dims={1}, start_index_map={1},
+ index_vector_dim=1, slice_sizes={1,1,3}
ROOT reshape = s32[1,3]{1,0} reshape(g.1)
}
@@ -441,11 +441,11 @@ ENTRY main {
operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}})
indices = s32[1] parameter(0)
gather = s32[1,6] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT reshape = s32[1,1,6] reshape(gather)
}
)";
@@ -469,11 +469,11 @@ ENTRY main {
{1,2,3,4,5,6},{1,2,3,4,5,6}}})
indices = s32[1] parameter(0)
gather = s32[1,1,6] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1,2},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={1,1,6}
+ slice_sizes={1,1,6}
ROOT reshape = s32[1,1,1,6] reshape(gather)
}
)";
@@ -500,11 +500,11 @@ ENTRY main {
{1,2,3,4,5,6},{1,2,3,4,5,6}})
indices = s32[1,5] parameter(0)
gather = s32[1,5,6] gather(operand, indices),
- output_window_dims={2},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={2},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT reshape = s32[1,1,5,6] reshape(gather)
}
)";
@@ -530,11 +530,11 @@ ENTRY main {
operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
indices = s32[5,6] parameter(0)
gather = s32[5,4,6] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT reshape = s32[5,2,2,2,3] reshape(gather)
}
)";
@@ -562,11 +562,11 @@ ENTRY main {
{{1,2},{3,4},{5,6},{7,8},{9,10}}})
indices = s32[7] parameter(0)
gather = s32[3,2,7] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0,1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1,2}
+ slice_sizes={3,1,2}
ROOT reshape = s32[6,7] reshape(gather)
}
)";
@@ -594,11 +594,11 @@ ENTRY main {
{{1},{2},{3},{4}}})
indices = s32[5,6] parameter(0)
gather = s32[5,4,6,1] gather(operand, indices),
- output_window_dims={1,3},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1,3},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,4,1}
+ slice_sizes={1,4,1}
ROOT reshape = s32[5,2,2,2,3,1] reshape(gather)
}
)";
@@ -623,11 +623,11 @@ ENTRY main {
operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
indices = s32[5] parameter(0)
gather = f32[5,4] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT tanh = f32[5,4] tanh(gather)
}
)";
@@ -650,11 +650,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";
@@ -678,11 +678,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT sub = s32[5,4] subtract(gather, constant_broadcasted)
}
)";
@@ -706,11 +706,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT sub = s32[5,4] subtract(constant_broadcasted, gather)
}
)";
@@ -733,11 +733,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";
@@ -760,11 +760,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";
@@ -808,11 +808,11 @@ ENTRY main {
dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
indices = s32[5] parameter(0)
dot_lhs = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
@@ -835,11 +835,11 @@ ENTRY main {
dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}})
indices = s32[5] parameter(0)
dot_lhs = s32[3,5] gather(gather_operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={0}, rhs_contracting_dims={0}
}
)";
@@ -863,11 +863,11 @@ ENTRY main {
dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
indices = s32[5] parameter(0)
dot_rhs = s32[3,5] gather(gather_operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
@@ -892,11 +892,11 @@ ENTRY main {
dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
indices = s32[5] parameter(0)
dot_rhs = s32[5,3] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1}
}
)";
@@ -921,11 +921,11 @@ ENTRY main {
dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}})
indices = s32[4] parameter(0)
dot_rhs = s32[2,3,4] gather(gather_operand, indices),
- output_window_dims={0,1},
- elided_window_dims={2},
- gather_dims_to_operand_dims={2},
+ offset_dims={0,1},
+ collapsed_slice_dims={2},
+ start_index_map={2},
index_vector_dim=1,
- window_bounds={2,3,1}
+ slice_sizes={2,3,1}
ROOT dot = s32[2,2,4] dot(dot_lhs_constant, dot_rhs),
lhs_contracting_dims={2}, rhs_contracting_dims={1},
lhs_batch_dims={0}, rhs_batch_dims={0}
@@ -952,11 +952,11 @@ ENTRY main {
dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}})
indices = s32[2] parameter(0)
dot_lhs = s32[3,2] gather(gather_operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT dot = s32[3,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index ec5743a777..cc1ec1704e 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -2492,201 +2492,196 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
static Status ValidateGatherDimensionNumbers(
const Shape& input_shape,
- tensorflow::gtl::ArraySlice<int64> gather_indices_shape,
+ tensorflow::gtl::ArraySlice<int64> start_indices_shape,
const GatherDimensionNumbers& dim_numbers) {
- if (!c_is_sorted(dim_numbers.output_window_dims())) {
+ if (!c_is_sorted(dim_numbers.offset_dims())) {
return InvalidArgument(
"Output window dimensions in gather op must be ascending; got: %s.",
- Join(dim_numbers.output_window_dims(), ", ").c_str());
+ Join(dim_numbers.offset_dims(), ", ").c_str());
}
- if (c_adjacent_find(dim_numbers.output_window_dims()) !=
- dim_numbers.output_window_dims().end()) {
+ if (c_adjacent_find(dim_numbers.offset_dims()) !=
+ dim_numbers.offset_dims().end()) {
return InvalidArgument(
"Output window dimensions in gather op must not repeat; got: %s.",
- Join(dim_numbers.output_window_dims(), ", ").c_str());
+ Join(dim_numbers.offset_dims(), ", ").c_str());
}
- const int64 output_window_dim_count = dim_numbers.output_window_dims_size();
+ const int64 output_offset_dim_count = dim_numbers.offset_dims_size();
const int64 output_shape_rank =
- output_window_dim_count + gather_indices_shape.size() - 1;
+ output_offset_dim_count + start_indices_shape.size() - 1;
- for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) {
- int64 window_index = dim_numbers.output_window_dims(i);
- if (window_index < 0 || window_index >= output_shape_rank) {
+ for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) {
+ int64 offset_dim = dim_numbers.offset_dims(i);
+ if (offset_dim < 0 || offset_dim >= output_shape_rank) {
return InvalidArgument(
- "Window index %d in gather op is out of bounds; got %lld, but should "
+ "Offset dimension %d in gather op is out of bounds; got %lld, but "
+ "should "
"have been in [0,%lld).",
- i, window_index, output_shape_rank);
+ i, offset_dim, output_shape_rank);
}
}
- if (dim_numbers.gather_dims_to_operand_dims_size() !=
- gather_indices_shape[dim_numbers.index_vector_dim()]) {
+ if (dim_numbers.start_index_map_size() !=
+ start_indices_shape[dim_numbers.index_vector_dim()]) {
return InvalidArgument(
- "Gather op has %d elements in gather_dims_to_operand_dims and the "
- "bound of dimension index_vector_dim=%lld of gather_indices is "
+ "Gather op has %d elements in start_index_map and the "
+ "bound of dimension index_vector_dim=%lld of start_indices is "
"%lld. These two numbers must be equal.",
- dim_numbers.gather_dims_to_operand_dims_size(),
- dim_numbers.index_vector_dim(),
- gather_indices_shape[dim_numbers.index_vector_dim()]);
+ dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
+ start_indices_shape[dim_numbers.index_vector_dim()]);
}
- for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) {
- int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i);
- if (gather_dim_to_input_dim < 0 ||
- gather_dim_to_input_dim >= input_shape.dimensions_size()) {
+ for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
+ int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
+ if (operand_dim_for_start_index_i < 0 ||
+ operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
return InvalidArgument(
- "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), "
- "got: %d->%lld.",
- input_shape.dimensions_size(), i, gather_dim_to_input_dim);
+ "Invalid start_index_map; domain is [0, %d), got: %d->%lld.",
+ input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
}
}
- std::vector<int64> sorted_gather_dims_to_operand_dims(
- dim_numbers.gather_dims_to_operand_dims().begin(),
- dim_numbers.gather_dims_to_operand_dims().end());
+ std::vector<int64> sorted_start_index_map(
+ dim_numbers.start_index_map().begin(),
+ dim_numbers.start_index_map().end());
- c_sort(sorted_gather_dims_to_operand_dims);
+ c_sort(sorted_start_index_map);
- if (c_adjacent_find(sorted_gather_dims_to_operand_dims) !=
- sorted_gather_dims_to_operand_dims.end()) {
+ if (c_adjacent_find(sorted_start_index_map) != sorted_start_index_map.end()) {
return InvalidArgument(
- "Repeated dimensions are not allowed in gather_dims_to_operand_dims; "
+ "Repeated dimensions are not allowed in start_index_map; "
"got: %s.",
- Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str());
+ Join(dim_numbers.start_index_map(), ", ").c_str());
}
- for (int64 elided_dim : dim_numbers.elided_window_dims()) {
- if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) {
+ for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) {
+ if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
return InvalidArgument(
- "Invalid elided_window_dims set in gather op; valid range is [0, "
+ "Invalid collapsed_slice_dims set in gather op; valid range is [0, "
"%d), got: %lld.",
- input_shape.dimensions_size(), elided_dim);
+ input_shape.dimensions_size(), collapsed_dim);
}
}
- if (!c_is_sorted(dim_numbers.elided_window_dims())) {
+ if (!c_is_sorted(dim_numbers.collapsed_slice_dims())) {
return InvalidArgument(
- "elided_window_dims in gather op must be sorted; got: %s",
- Join(dim_numbers.elided_window_dims(), ", ").c_str());
+ "collapsed_slice_dims in gather op must be sorted; got: %s",
+ Join(dim_numbers.collapsed_slice_dims(), ", ").c_str());
}
- if (c_adjacent_find(dim_numbers.elided_window_dims()) !=
- dim_numbers.elided_window_dims().end()) {
+ if (c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
+ dim_numbers.collapsed_slice_dims().end()) {
return InvalidArgument(
- "Repeated dimensions not allowed in elided_window_dims in gather op; "
+ "Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
"got: %s.",
- Join(dim_numbers.elided_window_dims(), ", ").c_str());
+ Join(dim_numbers.collapsed_slice_dims(), ", ").c_str());
}
return Status::OK();
}
/*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
- const Shape& input_shape, const Shape& gather_indices_shape,
+ const Shape& input_shape, const Shape& start_indices_shape,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
TF_RETURN_IF_ERROR(
ExpectArray(input_shape, "input tensor operand gather op"));
TF_RETURN_IF_ERROR(
- ExpectArray(gather_indices_shape, "gather indices operand of gather op"));
+ ExpectArray(start_indices_shape, "gather indices operand of gather op"));
- if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
+ if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
return InvalidArgument(
"Gather indices parameter must be an integral tensor; got %s.",
- ShapeUtil::HumanString(gather_indices_shape).c_str());
+ ShapeUtil::HumanString(start_indices_shape).c_str());
}
// We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
// index_vector_dim is rank(P). The bounds of this expanded shape is
- // stored in expanded_gather_indices_shape.
+ // stored in expanded_start_indices_shape.
- if (gather_indices_shape.dimensions_size() <
+ if (start_indices_shape.dimensions_size() <
gather_dim_numbers.index_vector_dim() ||
gather_dim_numbers.index_vector_dim() < 0) {
return InvalidArgument(
- "Gather index leaf dimension must be within [0, rank(gather_indices) + "
- "1). rank(gather_indices) is %d and gather index leaf dimension is "
+ "Gather index leaf dimension must be within [0, rank(start_indices) + "
+ "1). rank(start_indices) is %d and gather index leaf dimension is "
"%lld.",
- gather_indices_shape.dimensions_size(),
+ start_indices_shape.dimensions_size(),
gather_dim_numbers.index_vector_dim());
}
- std::vector<int64> expanded_gather_indices_shape;
- expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size());
- c_copy(gather_indices_shape.dimensions(),
- std::back_inserter(expanded_gather_indices_shape));
- if (expanded_gather_indices_shape.size() ==
+ std::vector<int64> expanded_start_indices_shape;
+ expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
+ c_copy(start_indices_shape.dimensions(),
+ std::back_inserter(expanded_start_indices_shape));
+ if (expanded_start_indices_shape.size() ==
gather_dim_numbers.index_vector_dim()) {
- expanded_gather_indices_shape.push_back(1);
+ expanded_start_indices_shape.push_back(1);
}
TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
- input_shape, expanded_gather_indices_shape, gather_dim_numbers));
+ input_shape, expanded_start_indices_shape, gather_dim_numbers));
- if (window_bounds.size() != input_shape.dimensions_size()) {
+ if (slice_sizes.size() != input_shape.dimensions_size()) {
return InvalidArgument(
- "Gather op must have one window bound for every input dimension; got: "
- "len(window_bounds)=%lu, input_shape.rank=%d.",
- window_bounds.size(), input_shape.dimensions_size());
+ "Gather op must have one slice size for every input dimension; got: "
+ "len(slice_sizes)=%lu, input_shape.rank=%d.",
+ slice_sizes.size(), input_shape.dimensions_size());
}
- if (window_bounds.size() !=
- gather_dim_numbers.output_window_dims_size() +
- gather_dim_numbers.elided_window_dims_size()) {
+ if (slice_sizes.size() !=
+ gather_dim_numbers.offset_dims_size() +
+ gather_dim_numbers.collapsed_slice_dims_size()) {
return InvalidArgument(
- "All components of the window index in a gather op must either be a "
- "output window index or explicitly elided; got len(window_bounds)=%lu, "
- "output_window_bounds=%s, elided_window_bounds=%s.",
- window_bounds.size(),
- Join(gather_dim_numbers.output_window_dims(), ",").c_str(),
- Join(gather_dim_numbers.elided_window_dims(), ",").c_str());
+ "All components of the offset index in a gather op must either be a "
+ "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
+ "output_slice_sizes=%s, collapsed_slice_dims=%s.",
+ slice_sizes.size(), Join(gather_dim_numbers.offset_dims(), ",").c_str(),
+ Join(gather_dim_numbers.collapsed_slice_dims(), ",").c_str());
}
- for (int i = 0; i < window_bounds.size(); i++) {
- int64 window_bound = window_bounds[i];
- int64 corresponding_input_bound = input_shape.dimensions(i);
- if (window_bound < 0 || window_bound > corresponding_input_bound) {
+ for (int i = 0; i < slice_sizes.size(); i++) {
+ int64 slice_size = slice_sizes[i];
+ int64 corresponding_input_size = input_shape.dimensions(i);
+ if (slice_size < 0 || slice_size > corresponding_input_size) {
return InvalidArgument(
- "Window bound at index %d in gather op is out of range, must be "
- "within "
- "[0, %lld), got %lld.",
- i, corresponding_input_bound + 1, window_bound);
+ "Slice size at index %d in gather op is out of range, must be "
+ "within [0, %lld), got %lld.",
+ i, corresponding_input_size + 1, slice_size);
}
}
- for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) {
- if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) {
+ for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) {
+ if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) {
return InvalidArgument(
- "Gather op can only elide window indices with bound 1, but bound is "
+ "Gather op can only collapse slice dims with bound 1, but bound is "
"%lld for index %lld at position %d.",
- window_bounds[gather_dim_numbers.elided_window_dims(i)],
- gather_dim_numbers.elided_window_dims(i), i);
+ slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
+ gather_dim_numbers.collapsed_slice_dims(i), i);
}
}
- int64 result_rank = gather_dim_numbers.output_window_dims_size() +
- (expanded_gather_indices_shape.size() - 1);
- int64 window_dims_seen = 0;
+ int64 result_rank = gather_dim_numbers.offset_dims_size() +
+ (expanded_start_indices_shape.size() - 1);
+ int64 offset_dims_seen = 0;
int64 gather_dims_seen = 0;
std::vector<int64> output_dim_bounds;
output_dim_bounds.reserve(result_rank);
for (int64 i = 0; i < result_rank; i++) {
int64 current_bound;
- bool is_window_index =
- c_binary_search(gather_dim_numbers.output_window_dims(), i);
+ bool is_window_index = c_binary_search(gather_dim_numbers.offset_dims(), i);
if (is_window_index) {
- while (c_binary_search(gather_dim_numbers.elided_window_dims(),
- window_dims_seen)) {
- window_dims_seen++;
+ while (c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
+ offset_dims_seen)) {
+ offset_dims_seen++;
}
- current_bound = window_bounds[window_dims_seen++];
+ current_bound = slice_sizes[offset_dims_seen++];
} else {
if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
gather_dims_seen++;
}
- current_bound = expanded_gather_indices_shape[gather_dims_seen++];
+ current_bound = expanded_start_indices_shape[gather_dims_seen++];
}
output_dim_bounds.push_back(current_bound);
@@ -2837,25 +2832,25 @@ Status ValidateScatterDimensionNumbers(
scatter_dim_numbers));
int64 inserted_dims_seen = 0;
- std::vector<int64> max_update_window_bounds;
+ std::vector<int64> max_update_slice_sizes;
for (int i = 0; i < operand_shape.dimensions_size(); ++i) {
if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() &&
scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
++inserted_dims_seen;
} else {
- max_update_window_bounds.push_back(operand_shape.dimensions(i));
+ max_update_slice_sizes.push_back(operand_shape.dimensions(i));
}
}
for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
if (updates_shape.dimensions(update_window_dim) >
- max_update_window_bounds[i]) {
+ max_update_slice_sizes[i]) {
return InvalidArgument(
"Bounds of the window dimensions of updates must not exceed the "
"bounds of the corresponding dimensions of operand. For dimension "
"%lld, updates bound is %lld, operand bound is %lld.",
update_window_dim, updates_shape.dimensions(update_window_dim),
- max_update_window_bounds[i]);
+ max_update_slice_sizes[i]);
}
}
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index bfd79a4433..4974ac9916 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -276,9 +276,9 @@ class ShapeInference {
// with the given input shape, gather indices shape and gather dimension
// numbers.
static StatusOr<Shape> InferGatherShape(
- const Shape& input_shape, const Shape& gather_indices_shape,
+ const Shape& input_shape, const Shape& start_indices_shape,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
// Helper that validates the given input shape, scatter indices shape, updates
// shape, and scatter dimension numbers that constitute a scatter operation,
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index a73fa181cd..4ed8fc6b86 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1654,11 +1654,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) {
ShapeInference::InferGatherShape(
matrix_64_48_, s64_vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/1),
- /*window_bounds=*/{64, 1}));
+ /*slice_sizes=*/{64, 1}));
EXPECT_TRUE(
ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
<< ShapeUtil::HumanString(gather_shape);
@@ -1669,11 +1669,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) {
ShapeInference::InferGatherShape(
matrix_64_48_, s64_vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{1},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
+ /*offset_dims=*/{1},
+ /*collapsed_slice_dims=*/{0},
+ /*start_index_map=*/{0},
/*index_vector_dim=*/1),
- /*window_bounds=*/{1, 48}));
+ /*slice_sizes=*/{1, 48}));
EXPECT_TRUE(
ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
<< ShapeUtil::HumanString(gather_shape);
@@ -1684,11 +1684,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) {
ShapeInference::InferGatherShape(
matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
+ /*offset_dims=*/{4},
+ /*collapsed_slice_dims=*/{0},
+ /*start_index_map=*/{0},
/*index_vector_dim=*/4),
- /*window_bounds=*/{1, 48}));
+ /*slice_sizes=*/{1, 48}));
EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
<< ShapeUtil::HumanString(gather_shape);
@@ -1700,11 +1700,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(
gather_shape,
ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26})))
@@ -1717,11 +1717,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/2),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(
gather_shape,
@@ -1735,11 +1735,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/0),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(
gather_shape,
@@ -1749,16 +1749,15 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) {
// This is equivalent to a dynamic slice.
- TF_ASSERT_OK_AND_ASSIGN(
- Shape gather_shape,
- ShapeInference::InferGatherShape(
- f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
- HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0, 1, 2, 3, 4},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
- /*index_vector_dim=*/0),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*offset_dims=*/{0, 1, 2, 3, 4},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/0),
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26})))
@@ -1772,11 +1771,11 @@ TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0, 1, 2, 3},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
+ /*offset_dims=*/{0, 1, 2, 3},
+ /*collapsed_slice_dims=*/{0},
+ /*start_index_map=*/{0},
/*index_vector_dim=*/0),
- /*window_bounds=*/{1, 30, 29, 28, 27}));
+ /*slice_sizes=*/{1, 30, 29, 28, 27}));
EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
ShapeUtil::MakeShape(F32, {30, 29, 28, 27})))
@@ -1787,11 +1786,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
tuple_shape_, s64_vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/1),
- /*window_bounds=*/{64, 1});
+ /*slice_sizes=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Expected array argument for input"))
@@ -1802,11 +1801,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, tuple_shape_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/0),
- /*window_bounds=*/{64, 1});
+ /*slice_sizes=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Expected array argument for gather indices"))
@@ -1817,11 +1816,11 @@ TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/0),
- /*window_bounds=*/{64, 1});
+ /*slice_sizes=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Gather indices parameter must be an integral tensor"))
@@ -1833,11 +1832,11 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 8, 7},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 8, 7},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
@@ -1850,11 +1849,11 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 7},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 7},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
@@ -1867,14 +1866,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 99, 100, 101},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 99, 100, 101},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Window index 2 in gather op is out of bounds"))
+ HasSubstr("Offset dimension 2 in gather op is out of bounds"))
<< statusor.status();
}
@@ -1883,14 +1882,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 9},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 9},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Window index 4 in gather op is out of bounds"))
+ HasSubstr("Offset dimension 4 in gather op is out of bounds"))
<< statusor.status();
}
@@ -1899,16 +1898,16 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{4},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{4},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
- HasSubstr("All components of the window index in a gather op must either "
- "be a output window index or explicitly elided"))
+ HasSubstr("All components of the offset index in a gather op must either "
+ "be a offset dimension or explicitly collapsed"))
<< statusor.status();
}
@@ -1917,14 +1916,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{0, 1, 2, 3, 19},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{0, 1, 2, 3, 19},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Invalid elided_window_dims set in gather op; valid "
+ HasSubstr("Invalid collapsed_slice_dims set in gather op; valid "
"range is [0, 5), got: 19"))
<< statusor.status();
}
@@ -1934,16 +1933,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{0, 1, 2, 3, 3},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{0, 1, 2, 3, 3},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
- EXPECT_THAT(
- statusor.status().error_message(),
- HasSubstr(
- "Repeated dimensions not allowed in elided_window_dims in gather op"))
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Repeated dimensions not allowed in "
+ "collapsed_slice_dims in gather op"))
<< statusor.status();
}
@@ -1952,17 +1950,16 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
- EXPECT_THAT(
- statusor.status().error_message(),
- HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and "
- "the bound of dimension index_vector_dim=4 of "
- "gather_indices is 5. These two numbers must be equal."))
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Gather op has 4 elements in start_index_map and "
+ "the bound of dimension index_vector_dim=4 of "
+ "start_indices is 5. These two numbers must be equal."))
<< statusor.status();
}
@@ -1971,16 +1968,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 7},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
- EXPECT_THAT(
- statusor.status().error_message(),
- HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is "
- "[0, 5), got: 4->7"))
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7"))
<< statusor.status();
}
@@ -1989,16 +1984,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 3},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
- HasSubstr(
- "Repeated dimensions are not allowed in gather_dims_to_operand_dims"))
+ HasSubstr("Repeated dimensions are not allowed in start_index_map"))
<< statusor.status();
}
@@ -2007,14 +2001,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{2, 1},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{2, 1},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{1, 1, 28, 27, 26});
+ /*slice_sizes=*/{1, 1, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("elided_window_dims in gather op must be sorted"))
+ HasSubstr("collapsed_slice_dims in gather op must be sorted"))
<< statusor.status();
}
@@ -2023,15 +2017,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7},
- /*elided_window_dims=*/{2},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7},
+ /*collapsed_slice_dims=*/{2},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 1, 300, 26});
+ /*slice_sizes=*/{30, 29, 1, 300, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Window bound at index 3 in gather op is out of range, "
- "must be within [0, 48), got 300"))
+ HasSubstr("Slice size at index 3 in gather op is out of range, "
+ "must be within [0, 48), got 300."))
<< statusor.status();
}
@@ -2040,16 +2034,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 26});
+ /*slice_sizes=*/{30, 29, 28, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
- HasSubstr(
- "Gather op must have one window bound for every input dimension"))
+ HasSubstr("Gather op must have one slice size for every input dimension"))
<< statusor.status();
}
@@ -2058,15 +2051,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 26, 20});
+ /*slice_sizes=*/{30, 29, 28, 26, 20});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Gather op can only elide window indices with bound 1, "
- "but bound is 29 for index 1 at position 0"))
+ HasSubstr("Gather op can only collapse slice dims with bound 1, "
+ "but bound is 29 for index 1 at position 0."))
<< statusor.status();
}
@@ -2074,16 +2067,16 @@ TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/32),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Gather index leaf dimension must be within [0, "
- "rank(gather_indices) + 1)"))
+ "rank(start_indices) + 1)"))
<< statusor.status();
}
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index b77bece85a..f866ed6519 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -30,8 +30,8 @@ using tensorflow::gtl::nullopt;
class GatherOperationTest : public HloTestBase {
protected:
void RunTest(const string& hlo_text, Literal* operand,
- Literal* gather_indices) {
- RunTest(hlo_text, {operand, gather_indices});
+ Literal* start_indices) {
+ RunTest(hlo_text, {operand, start_indices});
}
void RunTest(const string& hlo_text,
@@ -52,18 +52,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 3}
+ slice_sizes={1, 3}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) {
@@ -74,18 +73,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) {
@@ -96,18 +94,18 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) {
@@ -118,18 +116,18 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) {
@@ -140,18 +138,18 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
ROOT gather = s32[2,1,1,2] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) {
@@ -162,20 +160,20 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) {
@@ -186,20 +184,20 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, DynamicSlice) {
@@ -210,18 +208,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) {
@@ -232,18 +229,18 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, ZeroDimBounds) {
@@ -254,17 +251,16 @@ ENTRY main {
operand = s32[3,0] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,0] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 0}
+ slice_sizes={1, 0}
}
)";
std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) {
@@ -278,19 +274,19 @@ ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) {
@@ -304,19 +300,19 @@ ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = u32[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<uint32>(
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<uint32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, NegativeIndex) {
@@ -330,19 +326,19 @@ ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) {
@@ -356,19 +352,19 @@ ENTRY main {
operand = u32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
gather = u32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = u32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<uint32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, OneScalarIndex) {
@@ -379,17 +375,17 @@ ENTRY main {
operand = s32[2,3,2]{2,1,0} parameter(0)
index = s32[] parameter(1)
ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index),
- output_window_dims={0,1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0},
+ offset_dims={0,1,2},
+ collapsed_slice_dims={},
+ start_index_map={0},
index_vector_dim=0,
- window_bounds={1,3,2}
+ slice_sizes={1,3,2}
}
)";
std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1);
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR0<int32>(1);
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, ScalarResult) {
@@ -400,16 +396,16 @@ ENTRY main {
operand = s32[4]{0} parameter(0)
index = s32[] parameter(1)
ROOT gather = s32[] gather(operand, index),
- output_window_dims={},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=0,
- window_bounds={1}
+ slice_sizes={1}
}
)";
std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1);
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR0<int32>(1);
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, ZeroSizedResult) {
@@ -420,17 +416,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[0] parameter(1)
ROOT gather = s32[0,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 3}
+ slice_sizes={1, 3}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR1<int32>({});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) {
@@ -441,11 +437,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[3,2] broadcast(one), dimensions={}
ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
@@ -453,9 +449,8 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) {
@@ -466,11 +461,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,3,2] broadcast(one), dimensions={}
ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
@@ -478,9 +473,9 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) {
@@ -491,11 +486,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -503,9 +498,9 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) {
@@ -516,11 +511,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -530,9 +525,9 @@ ENTRY main {
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest,
@@ -544,11 +539,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -558,9 +553,9 @@ ENTRY main {
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) {
@@ -571,11 +566,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[1,1] broadcast(one), dimensions={}
ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
@@ -583,9 +578,8 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) {
@@ -596,11 +590,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[2,1,1] broadcast(one), dimensions={}
ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
@@ -608,9 +602,9 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
class GatherClientLibraryTest : public ClientLibraryTestBase {};
@@ -622,11 +616,11 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
// operand = s32[3,3] parameter(0)
// indices = s32[2] parameter(1)
// ROOT gather = s32[2,3] gather(operand, indices),
- // output_window_dims={1},
- // elided_window_dims={0},
- // gather_dims_to_operand_dims={0},
+ // offset_dims={1},
+ // collapsed_slice_dims={0},
+ // start_index_map={0},
// index_vector_dim=1,
- // window_bounds={1, 3}
+ // slice_sizes={1, 3}
// }
XlaBuilder builder("gather_basic");
@@ -637,9 +631,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
auto operand = Parameter(&builder, 0, operand_shape, "operand");
auto indices = Parameter(&builder, 1, indices_shape, "indices");
GatherDimensionNumbers dim_numbers;
- dim_numbers.add_output_window_dims(1);
- dim_numbers.add_elided_window_dims(0);
- dim_numbers.add_gather_dims_to_operand_dims(0);
+ dim_numbers.add_offset_dims(1);
+ dim_numbers.add_collapsed_slice_dims(0);
+ dim_numbers.add_start_index_map(0);
dim_numbers.set_index_vector_dim(1);
Gather(operand, indices, dim_numbers, {1, 3});
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 4c35e93d38..27aa94c2cb 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -424,25 +424,25 @@ message GatherDimensionNumbers {
// "Window indices" is a term for a set of indices that index into the
// interior of a dynamic-slice from the input tensor, the starting indices for
// which were computed from output_gather_dims (see the operation semantic for
- // how this is defined) and the gather_indices tensor.
+ // how this is defined) and the start_indices tensor.
//
// The window indices for a specific output index Out is computed as:
//
// i = 0
// for (k : [0, input_tensor_shape.rank))
// window_indices[k] =
- // if k in elided_window_dims
+ // if k in collapsed_slice_dims
// then 0
- // else Out[output_window_dims[i++]]
- repeated int64 output_window_dims = 1;
- repeated int64 elided_window_dims = 2;
+ // else Out[offset_dims[i++]]
+ repeated int64 offset_dims = 1;
+ repeated int64 collapsed_slice_dims = 2;
- // This is interpreted as a map from i to gather_dims_to_operand_dims[i]. It
- // transforms the gather index looked up from the gather_indices tensor into
+ // This is interpreted as a map from i to start_index_map[i]. It
+ // transforms the gather index looked up from the start_indices tensor into
// the starting index in the input space.
- repeated int64 gather_dims_to_operand_dims = 3;
+ repeated int64 start_index_map = 3;
- // The dimension in the gather_indices input that contains the starting
+ // The dimension in the start_indices input that contains the starting
// indices.
int64 index_vector_dim = 4;
}
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 16dd3c5bf3..2de30d1b3d 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -1138,7 +1138,7 @@ array with the same shape. It is allowed for `operand` to be a scalar (rank 0).
## Gather
The XLA gather operation stitches together several slices (each slice at a
-potentially different runtime offset) of an input tensor into an output tensor.
+potentially different runtime offset) of an input array.
### General Semantics
@@ -1146,151 +1146,141 @@ See also
[`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
For a more intuitive description, see the "Informal Description" section below.
-<b> `gather(operand, gather_indices, output_window_dims, elided_window_dims, window_bounds, gather_dims_to_operand_dims)` </b>
+<b> `gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)` </b>
|Arguments | Type | Semantics |
|----------------- | ----------------------- | --------------------------------|
-|`operand` | `XlaOp` | The tensor we’re gathering |
+|`operand` | `XlaOp` | The array we’re gathering |
: : : from. :
-|`gather_indices` | `XlaOp` | Tensor containing the starting |
-: : : indices of the slices we're :
-: : : stitching together into the :
-: : : output tensor. :
-|`index_vector_dim` | `int64` | The dimension in |
-: : : `gather_indices` that contains :
-: : : the starting indices. :
-|`output_window_dims` | `ArraySlice<int64>` | The set of dimensions in the |
-: : : output shape that are _window :
-: : : dimensions_ (defined below). :
-: : : Not all window dimensions may :
-: : : be present in the output shape. :
-|`elided_window_dims` | `ArraySlice<int64>` | The set of _window dimensions_ |
-: : : that are not present in the output shape. :
-: : : `window_bounds[i]` must be `1` for all `i` :
-: : : in `elided_window_dims`. :
-|`window_bounds` | `ArraySlice<int64>` | `window_bounds[i]` is the bounds |
-: : : for window dimension `i`. This includes :
-: : : both the window dimensions that are :
-: : : explicitly part of the output shape (via :
-: : : `output_window_dims`) and the window :
-: : : dimensions that are elided (via :
-: : : `elided_window_dims`). :
-|`gather_dims_to_operand_dims` | `ArraySlice<int64>` | A dimension map (the |
-: : : array is interpreted as mapping `i` to :
-: : : `gather_dims_to_operand_dims[i]`) from :
-: : : the gather indices in `gather_indices` to :
-: : : the operand index space. It has to be :
-: : : one-to-one and total. :
-
-For every index `Out` in the output tensor, we compute two things (more
-precisely described later):
-
- - An index into `gather_indices.rank` - `1` dimensions of `gather_indices`,
- which gives us a starting index of a slice, _operand slice_, in the operand
- tensor. These `gather_indices.rank` - `1` dimensions are all the dimensions
- in `gather_indices` except `index_vector_dim`.
-
- - A _window index_ that has the same rank as the operand. This index is
- composed of the values in `Out` at dimensions `output_window_dims`, embedded
- with zeroes according to `elided_window_dims`.
-
-The _window index_ is the relative index of the element in _operand slice_ that
-should be present in the output at index `Out`.
-
-The output is a tensor of rank `output_window_dims.size` + `gather_indices.rank`
-- `1`. Additionally, as a shorthand, we define `output_gather_dims` of type
-`ArraySlice<int64>` as the set of dimensions in the output shape but not in
-`output_window_dims`, in ascending order. E.g. if the output tensor has rank
-`5`, `output_window_dims` is {`2`, `4`} then `output_gather_dims` is {`0`, `1`,
-`3`}
-
-If `index_vector_dim` is equal to `gather_indices.rank` we implicitly
-consider `gather_indices` to have a trailing `1` dimension (i.e. if
-`gather_indices` was of shape `[6,7]` and `index_vector_dim` is `2` then
-we implicitly consider the shape of `gather_indices` to be `[6,7,1]`).
-
-The bounds for the output tensor along dimension `i` is computed as follows:
-
- 1. If `i` is present in `output_gather_dims` (i.e. is equal to
- `output_gather_dims[k]` for some `k`) then we pick the corresponding
- dimension bounds out of `gather_indices.shape`, skipping
- `index_vector_dim` (i.e. pick `gather_indices.shape.dims`[`k`] if `k`
- < `index_vector_dim` and `gather_indices.shape.dims`[`k`+`1`]
- otherwise).
- 2. If `i` is present in `output_window_dims` (i.e. equal to
- `output_window_dims`[`k`] for some `k`) then we pick the corresponding
- bound out of `window_bounds` after accounting for `elided_window_dims`
- (i.e. we pick `adjusted_window_bounds`[`k`] where `adjusted_window_bounds`
- is `window_bounds` with the bounds at indices `elided_window_dims`
- removed).
-
-The operand index `In` corresponding to an output index `Out` is computed as
-follows:
-
- 1. Let `G` = { `Out`[`k`] for `k` in `output_gather_dims` }. Use `G` to slice
- out vector `S` such that `S`[`i`] = `gather_indices`[Combine(`G`, `i`)]
- where Combine(A, b) inserts b at position `index_vector_dim` into A.
- Note that this is well defined even if `G` is empty -- if `G` is empty then
- `S` = `gather_indices`.
- 2. Create an index, `S`<sub>`in`</sub>, into `operand` using `S` by
- scattering `S` using the `gather_dims_to_operand_dims` map
- (`S`<sub>`in`</sub> is the starting indices for _operand slice_ mentioned
- above). More precisely:
- 1. `S`<sub>`in`</sub>[`gather_dims_to_operand_dims`[`k`]] = `S`[`k`] if `k` <
- `gather_dims_to_operand_dims.size`.
+|`start_indices` | `XlaOp` | Array containing the starting |
+: : : indices of the slices we gather.:
+|`index_vector_dim` | `int64` | The dimension in |
+: : : `start_indices` that "contains" :
+: : : the starting indices. See :
+: : : below for a detailed :
+: : : description. :
+|`offset_dims` | `ArraySlice<int64>` | The set of dimensions in the :
+: : : output shape that offset into a :
+: : : array sliced from operand. :
+|`slice_sizes` | `ArraySlice<int64>` | `slice_sizes[i]` is the bounds |
+: : : for the slice on dimension `i`.:
+|`collapsed_slice_dims` | `ArraySlice<int64>` | The set of dimensions in each :
+| : | slice that are collapsed away. :
+| : | These dimensions must have size:
+| : | 1. |
+|`start_index_map` | `ArraySlice<int64>` | A map that describes how to map|
+: : : indices in `start_indices` to :
+: : : to legal indices into operand. :
+
+For convenience, we label dimensions in the output array not in `offset_dims`
+as `batch_dims`.
+
+The output is an array of rank `batch_dims.size` + `operand.rank` -
+`collapsed_slice_dims`.size.
+
+If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider
+`start_indices` to have a trailing `1` dimension (i.e. if `start_indices` was of
+shape `[6,7]` and `index_vector_dim` is `2` then we implicitly consider the
+shape of `start_indices` to be `[6,7,1]`).
+
+The bounds for the output array along dimension `i` is computed as follows:
+
+ 1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for
+ some `k`) then we pick the corresponding dimension bounds out of
+ `start_indices.shape`, skipping `index_vector_dim` (i.e. pick
+ `start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and
+ `start_indices.shape.dims`[`k`+`1`] otherwise).
+
+ 2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for
+ some `k`) then we pick the corresponding bound out of `slice_sizes` after
+ accounting for `collapsed_slice_dims` (i.e. we pick
+ `adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes`
+ with the bounds at indices `collapsed_slice_dims` removed).
+
+Formally, the operand index `In` corresponding to an output index `Out` is
+computed as follows:
+
+ 1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out
+ vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where
+ Combine(A, b) inserts b at position `index_vector_dim` into A. Note that
+ this is well defined even if `G` is empty -- if `G` is empty then `S` =
+ `start_indices`.
+
+ 2. Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by
+ scattering `S` using `start_index_map`. More precisely:
+ 1. `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` <
+ `start_index_map.size`.
2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
- 3. Create an index `W`<sub>`in`</sub> into `operand` by scattering the indices
- at the output window dimensions in `Out` according to
- the `elided_window_dims` set (`W`<sub>`in`</sub> is the _window index_
- mentioned above). More precisely:
- 1. `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `Out`[`k`] if
- `k` < `output_window_dims.size` (`window_dims_to_operand_dims` is
- defined below).
- 2. `W`<sub>`in`</sub>[`_`] = `0` otherwise.
- 4. `In` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
+
+ 3. Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices
+ at the offset dimensions in `Out` according to the `collapsed_slice_dims`
+ set. More precisely:
+ 1. `O`<sub>`in`</sub>[`expand_offset_dims`(`k`)] =
+ `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size`
+ (`expand_offset_dims` is defined below).
+ 2. `O`<sub>`in`</sub>[`_`] = `0` otherwise.
+ 4. `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
addition.
-`window_dims_to_operand_dims` is the monotonic function with domain [`0`,
-`output_window_dims.size`) and range [`0`, `operand.rank`) \
-`elided_window_dims`. So if, e.g., `output_window_dims.size` is `4`,
-`operand.rank` is `6` and `elided_window_dims` is {`0`, `2`} then
-`window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.
+`expand_offset_dims` is the monotonic function with domain [`0`, `offset.size`)
+and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g.,
+`offset.size` is `4`, `operand.rank` is `6` and `collapsed_slice_dims` is {`0`,
+`2`} then `expand_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.
### Informal Description and Examples
-`index_vector_dim` is set to `gather_indices.rank` - `1` in all of the
-examples that follow. More interesting values for `index_vector_dim`
-does not change the operation fundamentally, but makes the visual representation
-more cumbersome.
+Informally, every index `Out` in the output array corresponds to an element `E`
+in the operand array, computed as follows:
+
+ - We use the batch dimensions in `Out` to look up a starting index from
+ `start_indices`.
+
+ - We use `start_index_map` to map the starting index (which may have size less
+ than operand.rank) to a "full" starting index into operand.
+
+ - We dynamic-slice out a slice with size `slice_sizes` using the full starting
+ index.
+
+ - We reshape the slice by collapsing the `collapsed_slice_dims` dimensions.
+ Since all collapsed slice dimensions have to have bound 1 this reshape is
+ always legal.
+
+ - We use the offset dimensions in `Out` to index into this slice to get the
+ input element, `E`, corresponding to output index `Out`.
+
+`index_vector_dim` is set to `start_indices.rank` - `1` in all of the
+examples that follow. More interesting values for `index_vector_dim` does not
+change the operation fundamentally, but makes the visual representation more
+cumbersome.
To get an intuition on how all of the above fits together, let's look at an
-example that gathers 5 slices of shape `[8,6]` from a `[16,11]` tensor. The
-position of a slice into the `[16,11]` tensor can be represented as an index
+example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array. The
+position of a slice into the `[16,11]` array can be represented as an index
vector of shape `S64[2]`, so the set of 5 positions can be represented as a
-`S64[5,2]` tensor.
+`S64[5,2]` array.
The behavior of the gather operation can then be depicted as an index
-transformation that takes [`G`,`W`<sub>`0`</sub>,`W`<sub>`1`</sub>], an index in
-the output shape, and maps it to an element in the input tensor in the following
+transformation that takes [`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>], an index in
+the output shape, and maps it to an element in the input array in the following
way:
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../../images/ops_xla_gather_0.svg">
</div>
-We first select an (`X`,`Y`) vector from the gather indices tensor using `G`.
-The element in the output tensor at index
-[`G`,`W`<sub>`0`</sub>,`W`<sub>`1`</sub>] is then the element in the input
-tensor at index [`X`+`W`<sub>`0`</sub>,`Y`+`W`<sub>`1`</sub>].
+We first select an (`X`,`Y`) vector from the gather indices array using `G`.
+The element in the output array at index
+[`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>] is then the element in the input
+array at index [`X`+`O`<sub>`0`</sub>,`Y`+`O`<sub>`1`</sub>].
-`window_bounds` is `[8,6]`, which decides the range of W<sub>`0`</sub> and
+`slice_sizes` is `[8,6]`, which decides the range of W<sub>`0`</sub> and
W<sub>`1`</sub>, and this in turn decides the bounds of the slice.
This gather operation acts as a batch dynamic slice with `G` as the batch
dimension.
The gather indices may be multidimensional. For instance, a more general
-version of the example above using a "gather indices" tensor of shape `[4,5,2]`
+version of the example above using a "gather indices" array of shape `[4,5,2]`
would translate indices like this:
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
@@ -1298,25 +1288,25 @@ would translate indices like this:
</div>
Again, this acts as a batch dynamic slice `G`<sub>`0`</sub> and
-`G`<sub>`1`</sub> as the batch dimensions. The window bounds are still `[8,6]`.
+`G`<sub>`1`</sub> as the batch dimensions. The slice size is still `[8,6]`.
The gather operation in XLA generalizes the informal semantics outlined above in
the following ways:
- 1. We can configure which dimensions in the output shape are the window
- dimensions (dimensions containing `W`<sub>`0`</sub>, `W`<sub>`1`</sub> in
- the last example). The output gather dimensions (dimensions containing
+ 1. We can configure which dimensions in the output shape are the offset
+ dimensions (dimensions containing `O`<sub>`0`</sub>, `O`<sub>`1`</sub> in
+ the last example). The output batch dimensions (dimensions containing
`G`<sub>`0`</sub>, `G`<sub>`1`</sub> in the last example) are defined to be
- the output dimensions that are not window dimensions.
+ the output dimensions that are not offset dimensions.
- 2. The number of output window dimensions explicitly present in the output
+ 2. The number of output offset dimensions explicitly present in the output
shape may be smaller than the input rank. These "missing" dimensions, which
- are listed explicitly as `elided_window_dims`, must have a window bound of
- `1`. Since they have a window bound of `1` the only valid index for them is
+ are listed explicitly as `collapsed_slice_dims`, must have a slice size of
+ `1`. Since they have a slice size of `1` the only valid index for them is
`0` and eliding them does not introduce ambiguity.
- 3. The slice extracted from the "Gather Indices" tensor ((`X`, `Y`) in the last
- example) may have fewer elements than the input tensor rank, and an explicit
+ 3. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last
+ example) may have fewer elements than the input array rank, and an explicit
mapping dictates how the index should be expanded to have the same rank as
the input.
@@ -1327,20 +1317,19 @@ As a final example, we use (2) and (3) to implement `tf.gather_nd`:
</div>
`G`<sub>`0`</sub> and `G`<sub>`1`</sub> are used to slice out a starting index
-from the gather indices tensor as usual, except the starting index has only one
-element, `X`. Similarly, there is only one output window index with the value
-`W`<sub>`0`</sub>. However, before being used as indices into the input tensor,
-these are expanded in accordance to "Gather Index Mapping"
-(`gather_dims_to_operand_dims` in the formal description) and "Window Mapping"
-(`window_dims_to_operand_dims` in the formal description) into
-[`0`,`W`<sub>`0`</sub>] and [`X`,`0`] respectively, adding up to
-[`X`,`W`<sub>`0`</sub>]. In other words, the output index
-[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`W`<sub>`0`</sub>] maps to the input index
+from the gather indices array as usual, except the starting index has only one
+element, `X`. Similarly, there is only one output offset index with the value
+`O`<sub>`0`</sub>. However, before being used as indices into the input array,
+these are expanded in accordance to "Gather Index Mapping" (`start_index_map` in
+the formal description) and "Offset Mapping" (`expand_offset_dims` in the formal
+description) into [`0`,`O`<sub>`0`</sub>] and [`X`,`0`] respectively, adding up
+to [`X`,`O`<sub>`0`</sub>]. In other words, the output index
+[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index
[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us
the semantics for `tf.gather_nd`.
-`window_bounds` for this case is `[1,11]`. Intuitively this means that every
-index `X` in the gather indices tensor picks an entire row and the result is the
+`slice_sizes` for this case is `[1,11]`. Intuitively this means that every
+index `X` in the gather indices array picks an entire row and the result is the
concatenation of all these rows.
## GetTupleElement