 tensorflow/compiler/tf2xla/kernels/gather_op.cc                    |  32
 tensorflow/compiler/tf2xla/lib/triangular_solve.cc                 |  14
 tensorflow/compiler/xla/client/xla_builder.cc                      |  26
 tensorflow/compiler/xla/client/xla_builder.h                       |  12
 tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc |  56
 tensorflow/compiler/xla/service/elemental_ir_emitter.cc            |  11
 tensorflow/compiler/xla/service/gather_expander.cc                 | 180
 tensorflow/compiler/xla/service/gather_expander_test.cc            |  16
 tensorflow/compiler/xla/service/hlo.proto                          |   2
 tensorflow/compiler/xla/service/hlo_evaluator.cc                   | 161
 tensorflow/compiler/xla/service/hlo_evaluator_test.cc              | 112
 tensorflow/compiler/xla/service/hlo_instruction.cc                 |  24
 tensorflow/compiler/xla/service/hlo_instruction.h                  |   8
 tensorflow/compiler/xla/service/hlo_instruction_test.cc            |  66
 tensorflow/compiler/xla/service/hlo_instructions.cc                |  57
 tensorflow/compiler/xla/service/hlo_instructions.h                 |  16
 tensorflow/compiler/xla/service/hlo_parser.cc                      |  35
 tensorflow/compiler/xla/service/hlo_parser_test.cc                 |  10
 tensorflow/compiler/xla/service/hlo_verifier.cc                    |   2
 tensorflow/compiler/xla/service/indexed_array_analysis.cc          |  43
 tensorflow/compiler/xla/service/indexed_array_analysis.h           |   2
 tensorflow/compiler/xla/service/indexed_array_analysis_test.cc     | 300
 tensorflow/compiler/xla/service/shape_inference.cc                 | 201
 tensorflow/compiler/xla/service/shape_inference.h                  |   4
 tensorflow/compiler/xla/service/shape_inference_test.cc            | 269
 tensorflow/compiler/xla/tests/gather_operation_test.cc             | 312
 tensorflow/compiler/xla/xla_data.proto                             |  18
 tensorflow/docs_src/performance/xla/operation_semantics.md         | 267
 28 files changed, 1103 insertions(+), 1153 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index 35de96e0aa..44140304fd 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -95,11 +95,11 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
// operand = s32[3,3] parameter(0)
// indices = s32[2] parameter(1)
// gather = s32[3,2] gather(operand, indices),
- // output_window_dims={0},
- // elided_window_dims={1},
- // gather_dims_to_operand_dims={1},
+ // offset_dims={0},
+ // collapsed_slice_dims={1},
+ // start_index_map={1},
// index_vector_dim=1,
- // window_bounds={3, 1}
+ // slice_sizes={3, 1}
//
//
// Example of an N-D gather pulling out slices of shape [1,1,2] out of a
@@ -108,42 +108,42 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
// operand = s32[3,3,2] parameter(0)
// indices = s32[2,2] parameter(1)
// gather = s32[2,2] gather(operand, indices),
- // output_window_dims={1},
- // elided_window_dims={0,1},
- // gather_dims_to_operand_dims={0,1},
+ // offset_dims={1},
+ // collapsed_slice_dims={0,1},
+ // start_index_map={0,1},
// index_vector_dim=0,
- // window_bounds={1,1,2}
+ // slice_sizes={1,1,2}
xla::GatherDimensionNumbers dim_numbers;
- std::vector<int64> window_bounds;
- window_bounds.reserve(input_shape.dims());
+ std::vector<int64> slice_sizes;
+ slice_sizes.reserve(input_shape.dims());
for (int64 i = 0; i < input_shape.dims(); i++) {
int64 window_bound;
if (axis <= i && i < (axis + num_index_dims)) {
- dim_numbers.add_elided_window_dims(i);
+ dim_numbers.add_collapsed_slice_dims(i);
window_bound = 1;
} else {
window_bound = input_shape.dim_size(i);
}
- window_bounds.push_back(window_bound);
+ slice_sizes.push_back(window_bound);
if (i < axis) {
- dim_numbers.add_output_window_dims(i);
+ dim_numbers.add_offset_dims(i);
} else if (i >= (axis + num_index_dims)) {
int64 indices_rank =
indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims();
- dim_numbers.add_output_window_dims(i + indices_rank - num_index_dims);
+ dim_numbers.add_offset_dims(i + indices_rank - num_index_dims);
}
}
dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1)
: indices_shape.dims());
for (int64 i = axis; i < axis + num_index_dims; i++) {
- dim_numbers.add_gather_dims_to_operand_dims(i);
+ dim_numbers.add_start_index_map(i);
}
- *gather_output = xla::Gather(input, indices, dim_numbers, window_bounds);
+ *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes);
return Status::OK();
}
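The comment examples above show how tf.gather lowers to an HLO gather under the renamed attribute names. As a quick illustration of the renamed builder API, the first example can be built as in the following minimal sketch; the wrapper function and the proto include path are assumptions for illustration, while the GatherDimensionNumbers setters and the xla::Gather overload are the ones introduced in this change.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"  // GatherDimensionNumbers (assumed path)

// Builds the first gather from the comment above:
//   gather = s32[3,2] gather(operand, indices),
//     offset_dims={0}, collapsed_slice_dims={1}, start_index_map={1},
//     index_vector_dim=1, slice_sizes={3, 1}
// i.e. a tf.gather along axis 1 of an s32[3,3] operand with s32[2] indices.
xla::XlaOp BuildColumnGather(const xla::XlaOp& operand,
                             const xla::XlaOp& indices) {
  xla::GatherDimensionNumbers dnums;
  dnums.add_offset_dims(0);           // output dim 0 carries the 3-element slice
  dnums.add_collapsed_slice_dims(1);  // the size-1 slice along dim 1 is collapsed
  dnums.add_start_index_map(1);       // each start index addresses operand dim 1
  dnums.set_index_vector_dim(1);      // equals rank of `indices`, so indices are scalars
  return xla::Gather(operand, indices, dnums, /*slice_sizes=*/{3, 1});
}

Feeding operand {{1,2,3},{4,5,6},{7,8,9}} and indices {0,2} through this computation yields the s32[3,2] column gather {{1,3},{4,6},{7,9}}, matching the TensorFlowGatherV2 evaluator test further below.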
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index 04fa10108c..febb638e5e 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -57,7 +57,7 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) {
// We can grab entire blocks using gather
if (n > block_size) {
// Construct the starting indices of the diagonal blocks
- auto gather_indices =
+ auto start_indices =
Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks),
xla::ConstantR0<int32>(builder, block_size)),
/*broadcast_sizes=*/{2}),
@@ -65,13 +65,13 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) {
// Gather the diagonal blocks
xla::GatherDimensionNumbers dim_numbers;
- dim_numbers.add_output_window_dims(ndims - 1);
- dim_numbers.add_output_window_dims(ndims);
- dim_numbers.add_gather_dims_to_operand_dims(ndims - 2);
- dim_numbers.add_gather_dims_to_operand_dims(ndims - 1);
+ dim_numbers.add_offset_dims(ndims - 1);
+ dim_numbers.add_offset_dims(ndims);
+ dim_numbers.add_start_index_map(ndims - 2);
+ dim_numbers.add_start_index_map(ndims - 1);
dim_numbers.set_index_vector_dim(1);
- diag_blocks = Gather(a, gather_indices, dim_numbers,
- /*window_bounds=*/{block_size, block_size});
+ diag_blocks = Gather(a, start_indices, dim_numbers,
+ /*slice_sizes=*/{block_size, block_size});
}
// The last block might be smaller than the block size,
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index aa47f992bc..4dffab3c2c 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1631,27 +1631,27 @@ XlaOp XlaBuilder::While(const XlaComputation& condition,
});
}
-XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices,
+XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
- TF_ASSIGN_OR_RETURN(const Shape& gather_indices_shape,
- GetShape(gather_indices));
+ TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
+ GetShape(start_indices));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
- ShapeInference::InferGatherShape(input_shape, gather_indices_shape,
- dimension_numbers, window_bounds));
+ ShapeInference::InferGatherShape(input_shape, start_indices_shape,
+ dimension_numbers, slice_sizes));
*instr.mutable_gather_dimension_numbers() = dimension_numbers;
- for (int64 bound : window_bounds) {
- instr.add_gather_window_bounds(bound);
+ for (int64 bound : slice_sizes) {
+ instr.add_gather_slice_sizes(bound);
}
return AddInstruction(std::move(instr), HloOpcode::kGather,
- {input, gather_indices});
+ {input, start_indices});
});
}
@@ -2906,11 +2906,11 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
mantissa_bits);
}
-XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
- return input.builder()->Gather(input, gather_indices, dimension_numbers,
- window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ return input.builder()->Gather(input, start_indices, dimension_numbers,
+ slice_sizes);
}
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 78aec770a6..469d5048b2 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -877,9 +877,9 @@ class XlaBuilder {
const int mantissa_bits);
// Enqueues a Gather node onto the computation.
- XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
@@ -1328,9 +1328,9 @@ class XlaBuilder {
const XlaComputation& false_computation);
friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits);
- friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
const XlaOp& updates,
const XlaComputation& update_computation,
@@ -2024,9 +2024,9 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits);
// Enqueues a Gather node onto the computation.
-XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 8cbdc36f84..e6130c7d76 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -792,11 +792,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[3,2] broadcast(one), dimensions={}
ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
@@ -808,11 +808,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,3,2] broadcast(one), dimensions={}
ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
@@ -824,11 +824,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -840,11 +840,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -856,11 +856,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -872,11 +872,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[1,1] broadcast(one), dimensions={}
ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
@@ -888,11 +888,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[2,1,1] broadcast(one), dimensions={}
ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 2e9d6be2de..891ae42141 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -1672,22 +1672,21 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1);
for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0;
i < e; i++) {
- if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
+ if (c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
operand_index.push_back(index.GetConstantWithIndexType(0));
} else {
- int64 output_window_dim =
- dim_numbers.output_window_dims(operand_index_dim++);
+ int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++);
operand_to_output_dim[i] = output_window_dim;
operand_index.push_back(index[output_window_dim]);
}
}
- // This is the index of the index vector in the gather_indices tensor.
+ // This is the index of the index vector in the start_indices tensor.
IrArray::Index gather_index_index(index_type);
{
std::vector<llvm::Value*> gather_index_index_components;
for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
- if (!c_binary_search(dim_numbers.output_window_dims(), i)) {
+ if (!c_binary_search(dim_numbers.offset_dims(), i)) {
gather_index_index.push_back(index[i]);
}
}
@@ -1700,7 +1699,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
llvm::Value* gather_dim_component_extended =
b_->CreateSExtOrTrunc(index_component, index_type);
- int64 operand_dim = dim_numbers.gather_dims_to_operand_dims(dim);
+ int64 operand_dim = dim_numbers.start_index_map(dim);
int64 output_dim = operand_to_output_dim[operand_dim];
// If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
// This means we set the iteration index to 0, so for the purpose of the
diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc
index e3a42d0d06..9370c88710 100644
--- a/tensorflow/compiler/xla/service/gather_expander.cc
+++ b/tensorflow/compiler/xla/service/gather_expander.cc
@@ -27,85 +27,85 @@ namespace xla {
using tensorflow::gtl::ArraySlice;
static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
- HloInstruction* gather_indices, int64 index_vector_dim) {
- const Shape& gather_indices_shape = gather_indices->shape();
+ HloInstruction* start_indices, int64 index_vector_dim) {
+ const Shape& start_indices_shape = start_indices->shape();
- if (gather_indices_shape.dimensions_size() == index_vector_dim) {
- return gather_indices;
+ if (start_indices_shape.dimensions_size() == index_vector_dim) {
+ return start_indices;
}
- if (index_vector_dim == (gather_indices_shape.dimensions_size() - 1)) {
- return gather_indices;
+ if (index_vector_dim == (start_indices_shape.dimensions_size() - 1)) {
+ return start_indices;
}
std::vector<int64> permutation;
- permutation.reserve(gather_indices_shape.dimensions_size());
- for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
+ permutation.reserve(start_indices_shape.dimensions_size());
+ for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
if (i != index_vector_dim) {
permutation.push_back(i);
}
}
permutation.push_back(index_vector_dim);
- return MakeTransposeHlo(gather_indices, permutation);
+ return MakeTransposeHlo(start_indices, permutation);
}
-// Canonicalizes the gather_indices tensors so that we only have to deal with some
+// Canonicalizes the start_indices tensors so that we only have to deal with some
// specific cases in the while loop that does the heavy lifting.
//
// See the "High Level Algorithm" section for a broader picture.
static StatusOr<HloInstruction*> CanonicalizeGatherIndices(
- HloInstruction* gather_indices, int64 index_vector_dim) {
+ HloInstruction* start_indices, int64 index_vector_dim) {
// Transpose the non-index-vector dimensions to the front.
TF_ASSIGN_OR_RETURN(
- HloInstruction * transposed_gather_indices,
- TransposeIndexVectorDimToLast(gather_indices, index_vector_dim));
+ HloInstruction * transposed_start_indices,
+ TransposeIndexVectorDimToLast(start_indices, index_vector_dim));
bool indices_are_scalar =
- index_vector_dim == gather_indices->shape().dimensions_size();
+ index_vector_dim == start_indices->shape().dimensions_size();
- // The number of dimensions in gather_indices that are index dimensions.
- const int64 index_dims_in_gather_indices = indices_are_scalar ? 0 : 1;
+ // The number of dimensions in start_indices that are index dimensions.
+ const int64 index_dims_in_start_indices = indices_are_scalar ? 0 : 1;
- // If there is only one index (i.e. gather_indices has rank 1 and this gather
+ // If there is only one index (i.e. start_indices has rank 1 and this gather
// is really just a dynamic slice) add a leading degenerate dimension for
// uniformity. Otherwise create a "collapsed" leading dimension that subsumes
// all of the non-index-vector dimensions.
- const Shape& shape = transposed_gather_indices->shape();
- if (shape.dimensions_size() == index_dims_in_gather_indices) {
- return PrependDegenerateDims(transposed_gather_indices, 1);
+ const Shape& shape = transposed_start_indices->shape();
+ if (shape.dimensions_size() == index_dims_in_start_indices) {
+ return PrependDegenerateDims(transposed_start_indices, 1);
} else {
- // Collapse all but the dimensions (0 or 1) in gather_indices containing the
+ // Collapse all but the dimensions (0 or 1) in start_indices containing the
// index vectors.
return CollapseFirstNDims(
- transposed_gather_indices,
- shape.dimensions_size() - index_dims_in_gather_indices);
+ transposed_start_indices,
+ shape.dimensions_size() - index_dims_in_start_indices);
}
}
// Expands out or contracts away the gather dimensions in the accumulator
// produced by the while loop.
-static StatusOr<HloInstruction*> AdjustGatherDimsInAccumulator(
- const Shape& gather_indices_shape, HloInstruction* accumulator,
+static StatusOr<HloInstruction*> AdjustBatchDimsInAccumulator(
+ const Shape& start_indices_shape, HloInstruction* accumulator,
int64 index_vector_dim) {
- std::vector<int64> output_gather_dim_bounds;
- output_gather_dim_bounds.reserve(gather_indices_shape.dimensions_size());
- for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
+ std::vector<int64> batch_dim_bounds;
+ batch_dim_bounds.reserve(start_indices_shape.dimensions_size());
+ for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
if (i != index_vector_dim) {
- output_gather_dim_bounds.push_back(gather_indices_shape.dimensions(i));
+ batch_dim_bounds.push_back(start_indices_shape.dimensions(i));
}
}
- if (output_gather_dim_bounds.empty()) {
- // If output_gather_dim_bounds is empty we must be lowering a (effectively)
+ if (batch_dim_bounds.empty()) {
+ // If batch_dim_bounds is empty we must be lowering a (effectively)
// dynamic-slice. In that case, there is a leading degenerate gather
// dimension that we added to make this special case play well with the
// general while loop which we need to remove now.
return ElideDegenerateDims(accumulator, {0});
}
- return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds);
+ return ExpandFirstDimIntoNDims(accumulator, batch_dim_bounds);
}
-// Expand an index vector from the gather_indices tensor into a vector that can
+// Expand an index vector from the start_indices tensor into a vector that can
// be used to dynamic-slice out of the gather operand.
static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers,
@@ -121,10 +121,8 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
std::vector<HloInstruction*> expanded_index_components;
for (int i = 0; i < operand_rank; i++) {
- int64 index_vector_dim_index =
- FindIndex(dim_numbers.gather_dims_to_operand_dims(), i);
- if (index_vector_dim_index !=
- dim_numbers.gather_dims_to_operand_dims_size()) {
+ int64 index_vector_dim_index = FindIndex(dim_numbers.start_index_map(), i);
+ if (index_vector_dim_index != dim_numbers.start_index_map_size()) {
TF_ASSIGN_OR_RETURN(
HloInstruction * component_to_concat,
MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
@@ -147,10 +145,10 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers();
CHECK_EQ(incoming_loop_state.size(), 3);
HloInstruction* const operand = incoming_loop_state[0];
- HloInstruction* const gather_indices = incoming_loop_state[1];
+ HloInstruction* const start_indices = incoming_loop_state[1];
HloInstruction* const output_accumulator = incoming_loop_state[2];
- bool has_scalar_indices = gather_indices->shape().dimensions_size() == 1;
+ bool has_scalar_indices = start_indices->shape().dimensions_size() == 1;
CHECK_EQ(has_scalar_indices,
dim_numbers.index_vector_dim() ==
gather.operand(1)->shape().dimensions_size());
@@ -163,24 +161,24 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
HloInstruction* index_vector;
if (has_scalar_indices) {
- // In this case gather_indices has rank 1 and induction_var_as_vector (of
+ // In this case start_indices has rank 1 and induction_var_as_vector (of
// shape {1}) is an index into this rank 1 tensor.
TF_ASSIGN_OR_RETURN(
index_vector,
- MakeDynamicSliceHlo(gather_indices, induction_var_as_vector, {1}));
+ MakeDynamicSliceHlo(start_indices, induction_var_as_vector, {1}));
} else {
- // In this case gather_indices has rank 2 and induction_var_as_vector (of
+ // In this case start_indices has rank 2 and induction_var_as_vector (of
// shape {1}) is an index into just the first dimension of this rank 2
// tensor.
TF_ASSIGN_OR_RETURN(
- HloInstruction * index_into_gather_indices,
+ HloInstruction * index_into_start_indices,
PadVectorWithZeros(induction_var_as_vector,
/*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
- int64 index_vector_size = gather_indices->shape().dimensions(1);
+ int64 index_vector_size = start_indices->shape().dimensions(1);
TF_ASSIGN_OR_RETURN(
HloInstruction * index_vector_2d,
- MakeDynamicSliceHlo(gather_indices, index_into_gather_indices,
+ MakeDynamicSliceHlo(start_indices, index_into_start_indices,
{1, index_vector_size}));
TF_ASSIGN_OR_RETURN(index_vector,
@@ -194,26 +192,26 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice,
MakeDynamicSliceHlo(operand, gathered_slice_start,
- gather.gather_window_bounds()));
+ gather.gather_slice_sizes()));
TF_ASSIGN_OR_RETURN(
- HloInstruction * gathered_slice_with_dims_elided,
+ HloInstruction* const gathered_slice_with_dims_collapsed,
ElideDegenerateDims(gathered_slice,
- AsInt64Slice(dim_numbers.elided_window_dims())));
+ AsInt64Slice(dim_numbers.collapsed_slice_dims())));
TF_ASSIGN_OR_RETURN(
- HloInstruction * gathered_slice_for_update,
- PrependDegenerateDims(gathered_slice_with_dims_elided, 1));
+ HloInstruction* const gathered_slice_for_update,
+ PrependDegenerateDims(gathered_slice_with_dims_collapsed, 1));
TF_ASSIGN_OR_RETURN(
- HloInstruction * index_vector_into_accumulator,
+ HloInstruction* const index_vector_into_accumulator,
PadVectorWithZeros(
induction_var_as_vector, /*zeros_to_prepend=*/0,
/*zeros_to_append=*/
- gathered_slice_with_dims_elided->shape().dimensions_size()));
+ gathered_slice_with_dims_collapsed->shape().dimensions_size()));
TF_ASSIGN_OR_RETURN(
- HloInstruction * updated_accumulator,
+ HloInstruction* const updated_accumulator,
MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update,
index_vector_into_accumulator));
@@ -221,19 +219,19 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
// WhileUtil::MakeCountedLoop functions takes care of the induction variable
// and the while loop exit condition.
return StatusOr<std::vector<HloInstruction*>>{
- {operand, gather_indices, updated_accumulator}};
+ {operand, start_indices, updated_accumulator}};
}
static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue(
HloComputation* computation, PrimitiveType element_type,
- ArraySlice<int64> window_bounds, int64 gather_loop_trip_count,
+ ArraySlice<int64> slice_sizes, int64 gather_loop_trip_count,
const GatherDimensionNumbers& dim_numbers) {
std::vector<int64> accumulator_state_shape_dims;
- accumulator_state_shape_dims.reserve(1 + window_bounds.size());
+ accumulator_state_shape_dims.reserve(1 + slice_sizes.size());
accumulator_state_shape_dims.push_back(gather_loop_trip_count);
- for (int64 i = 0; i < window_bounds.size(); i++) {
- if (!c_binary_search(dim_numbers.elided_window_dims(), i)) {
- accumulator_state_shape_dims.push_back(window_bounds[i]);
+ for (int64 i = 0; i < slice_sizes.size(); i++) {
+ if (!c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
+ accumulator_state_shape_dims.push_back(slice_sizes[i]);
}
}
return BroadcastZeros(computation, element_type,
@@ -241,23 +239,23 @@ static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue(
}
// `accumulator` is almost the tensor the gather operation would have produced,
-// except that it has the dimensions in the wrong order -- the gather dimensions
-// are the major dimensions and the window dimensions are the minor dimensions.
+// except that it has the dimensions in the wrong order -- the batch dimensions
+// are the major dimensions and the offset dimensions are the minor dimensions.
// Fix this up with a transpose.
-static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
- HloInstruction* accumulator, ArraySlice<int64> output_window_dims,
+static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims(
+ HloInstruction* accumulator, ArraySlice<int64> offset_dims,
int64 output_rank) {
std::vector<int64> permutation;
permutation.reserve(output_rank);
- int64 gather_idx_counter = 0;
- int64 window_idx_counter = output_rank - output_window_dims.size();
+ int64 batch_idx_counter = 0;
+ int64 offset_idx_counter = output_rank - offset_dims.size();
for (int64 i = 0; i < output_rank; i++) {
- bool is_window_dim = c_binary_search(output_window_dims, i);
- if (is_window_dim) {
- permutation.push_back(window_idx_counter++);
+ bool is_offset_dim = c_binary_search(offset_dims, i);
+ if (is_offset_dim) {
+ permutation.push_back(offset_idx_counter++);
} else {
- permutation.push_back(gather_idx_counter++);
+ permutation.push_back(batch_idx_counter++);
}
}
@@ -268,11 +266,11 @@ static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
//
// We follow the following steps in sequence:
//
-// 1. We canonicalize the gather_indices tensor such that it has rank
+// 1. We canonicalize the start_indices tensor such that it has rank
// 2 (i.e. is a matrix) where each row is an index vector into the
// operand.
// 2. We iterate over the set of indices in the canonicalized
-// gather_indices tensor using a while loop, accumulating slices
+// start_indices tensor using a while loop, accumulating slices
// of the operand tensor into an accumulator using
// DynamicUpdateSlice.
// 3. The accumulator result from the while loop from (2) is then
@@ -287,11 +285,11 @@ static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
// operand = s32[3,3] parameter(0)
// indices = s32[2,2] parameter(1)
// ROOT gather = s32[2,3,2] gather(operand, indices),
-// output_window_dims={1},
-// elided_window_dims={1},
-// gather_dims_to_operand_dims={1},
+// offset_dims={1},
+// collapsed_slice_dims={1},
+// start_index_map={1},
// index_vector_dim=2,
-// window_bounds={3, 1}
+// slice_sizes={3, 1}
// }
//
// We'd first reshape indices to s32[4,1], where each row is an index
@@ -305,8 +303,8 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
HloComputation* computation = gather_instr->parent();
HloInstruction* operand = gather_instr->mutable_operand(0);
- HloInstruction* gather_indices = gather_instr->mutable_operand(1);
- const Shape& gather_indices_shape = gather_indices->shape();
+ HloInstruction* start_indices = gather_instr->mutable_operand(1);
+ const Shape& start_indices_shape = start_indices->shape();
const Shape& output_shape = gather_instr->shape();
int64 output_rank = output_shape.dimensions_size();
@@ -314,9 +312,9 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
gather_instr->gather_dimension_numbers();
int64 gather_loop_trip_count = 1;
- for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
+ for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
if (i != dim_numbers.index_vector_dim()) {
- gather_loop_trip_count *= gather_indices_shape.dimensions(i);
+ gather_loop_trip_count *= start_indices_shape.dimensions(i);
}
}
@@ -327,24 +325,24 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
gather_instr->ToString().c_str());
}
- TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices,
- CanonicalizeGatherIndices(
- gather_indices, dim_numbers.index_vector_dim()));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * canonical_start_indices,
+ CanonicalizeGatherIndices(start_indices, dim_numbers.index_vector_dim()));
CHECK_EQ(gather_loop_trip_count,
- canonical_gather_indices->shape().dimensions(0));
+ canonical_start_indices->shape().dimensions(0));
TF_ASSIGN_OR_RETURN(
HloInstruction * accumulator_init,
CreateGatherLoopAccumulatorInitValue(
computation, output_shape.element_type(),
- gather_instr->gather_window_bounds(), gather_loop_trip_count,
+ gather_instr->gather_slice_sizes(), gather_loop_trip_count,
gather_instr->gather_dimension_numbers()));
StatusOr<std::vector<HloInstruction*>> gather_loop_result_or_error =
WhileUtil::MakeCountedLoop(
computation, gather_loop_trip_count,
- {operand, canonical_gather_indices, accumulator_init},
+ {operand, canonical_start_indices, accumulator_init},
[&](HloInstruction* indvar,
const std::vector<HloInstruction*>& loop_state) {
return GatherLoopBody(*gather_instr, indvar, loop_state);
@@ -356,13 +354,13 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
HloInstruction* accumulator_result = gather_loop_result.back();
TF_ASSIGN_OR_RETURN(
- HloInstruction * accumulator_with_output_gather_dims_decanonicalized,
- AdjustGatherDimsInAccumulator(gather_indices->shape(), accumulator_result,
- dim_numbers.index_vector_dim()));
+ HloInstruction* const accumulator_with_batch_dims_decanonicalized,
+ AdjustBatchDimsInAccumulator(start_indices->shape(), accumulator_result,
+ dim_numbers.index_vector_dim()));
- return PermuteGatherAndWindowDims(
- accumulator_with_output_gather_dims_decanonicalized,
- AsInt64Slice(dim_numbers.output_window_dims()), output_rank);
+ return PermuteBatchAndOffsetDims(accumulator_with_batch_dims_decanonicalized,
+ AsInt64Slice(dim_numbers.offset_dims()),
+ output_rank);
}
StatusOr<bool> GatherExpander::Run(HloModule* module) {
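For reference, the expansion strategy documented in the "High Level Algorithm" comment of this file (canonicalize start_indices to a rank-2 matrix, accumulate one slice per row in a counted loop, then restore the batch dimensions and permute batch vs. offset dims) can be traced in plain C++ for the s32[2,3,2] example used in that comment. This is only an illustrative sketch with hard-coded shapes; the values match the TensorFlowGatherMultipleBatchDims evaluator test further below, and none of this is XLA code.

#include <cstdio>
#include <vector>

// Walk-through of the expansion for: operand s32[3,3], indices s32[2,2],
// offset_dims={1}, collapsed_slice_dims={1}, start_index_map={1},
// index_vector_dim=2, slice_sizes={3,1}  =>  output s32[2,3,2].
int main() {
  int operand[3][3] = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
  int indices[2][2] = {{0, 2}, {2, 1}};

  // Step 1: canonicalize start_indices into a [4,1] matrix of index vectors.
  std::vector<int> flat_indices;
  for (int b0 = 0; b0 < 2; b0++)
    for (int b1 = 0; b1 < 2; b1++) flat_indices.push_back(indices[b0][b1]);

  // Step 2: counted loop; each iteration slices a {3,1} window out of the
  // operand, collapses the size-1 dim, and writes row i of a [4,3] accumulator.
  int accumulator[4][3];
  for (int i = 0; i < 4; i++) {
    int col = flat_indices[i];  // start_index_map={1}: the index picks operand dim 1
    for (int r = 0; r < 3; r++) accumulator[i][r] = operand[r][col];
  }

  // Step 3: expand the collapsed batch dim back to [2,2,3], then permute so the
  // batch dims become output dims {0,2} and the offset dim becomes output dim 1.
  int output[2][3][2];
  for (int b0 = 0; b0 < 2; b0++)
    for (int b1 = 0; b1 < 2; b1++)
      for (int r = 0; r < 3; r++)
        output[b0][r][b1] = accumulator[b0 * 2 + b1][r];

  // Prints {{{1,3},{4,6},{7,9}}, {{3,2},{6,5},{9,8}}}, as in the evaluator test.
  for (int b0 = 0; b0 < 2; b0++) {
    for (int r = 0; r < 3; r++)
      printf("{%d,%d} ", output[b0][r][0], output[b0][r][1]);
    printf("\n");
  }
  return 0;
}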
diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc
index 020ffcd106..141dd4d6f1 100644
--- a/tensorflow/compiler/xla/service/gather_expander_test.cc
+++ b/tensorflow/compiler/xla/service/gather_expander_test.cc
@@ -28,11 +28,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2147483647,5] parameter(1)
ROOT gather = s32[2147483647,3,5] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
@@ -55,11 +55,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 9d24b42401..fa218657fe 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -139,7 +139,7 @@ message HloInstructionProto {
// Gather dimension numbers.
xla.GatherDimensionNumbers gather_dimension_numbers = 33;
- repeated int64 gather_window_bounds = 34;
+ repeated int64 gather_slice_sizes = 34;
// Compute Host.
string channel_name = 41;
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 51353eea6e..36d6a2eed6 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -555,43 +555,39 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
return Status::OK();
}
-// Returns an ShapeUtil::IndexIterationSpace that iterates over the output
-// gather dimensions while keeping the rest of the output dimensions clamped to
-// 0.
-ShapeUtil::IndexIterationSpace IterationSpaceForOutputGatherIndices(
+// Returns an ShapeUtil::IndexIterationSpace that iterates over the output batch
+// dimensions while keeping the rest of the output dimensions clamped to 0.
+ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
int64 output_rank = output_shape.dimensions_size();
std::vector<int64> index_base(output_rank, 0);
std::vector<int64> index_count;
index_count.reserve(output_rank);
for (int64 i = 0; i < output_rank; i++) {
- bool is_output_gather_dim =
- !c_binary_search(dim_numbers.output_window_dims(), i);
- index_count.push_back(is_output_gather_dim ? output_shape.dimensions(i)
- : 1);
+ bool is_output_batch_dim = !c_binary_search(dim_numbers.offset_dims(), i);
+ index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
}
return {std::move(index_base), std::move(index_count),
std::vector<int64>(output_rank, 1)};
}
-// Return an ShapeUtil::IndexIterationSpace that iterates over the output window
+// Return an ShapeUtil::IndexIterationSpace that iterates over the output slice
// dimensions while keeping the rest of the output dimensions clamped to 0.
-ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices(
- int64 output_rank, ArraySlice<int64> window_bounds,
+ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
+ int64 output_rank, ArraySlice<int64> slice_sizes,
const GatherDimensionNumbers& dim_numbers) {
std::vector<int64> index_base(output_rank, 0);
std::vector<int64> index_count(output_rank, 1);
- int64 window_bounds_idx = 0;
+ int64 slice_sizes_idx = 0;
for (int64 i = 0; i < output_rank; i++) {
- bool is_output_window_dim =
- c_binary_search(dim_numbers.output_window_dims(), i);
+ bool is_output_window_dim = c_binary_search(dim_numbers.offset_dims(), i);
if (is_output_window_dim) {
- while (c_binary_search(dim_numbers.elided_window_dims(),
- window_bounds_idx)) {
- window_bounds_idx++;
+ while (c_binary_search(dim_numbers.collapsed_slice_dims(),
+ slice_sizes_idx)) {
+ slice_sizes_idx++;
}
- index_count[i] = window_bounds[window_bounds_idx++];
+ index_count[i] = slice_sizes[slice_sizes_idx++];
}
}
@@ -599,30 +595,30 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices(
std::vector<int64>(output_rank, 1)};
}
-// This functor computes the contribution of gather_indices to an input index
+// This functor computes the contribution of start_indices to an input index
// corresponding to an output index. That is, given an output index I, it picks
-// out the gather output indices in I and uses them to look up a gather index,
-// G, from the gather indices tensor, and expands G into the input space
-// according to gather_dims_to_operand_dims.
-class OutputGatherIndexToInputIndex {
+// out the batch indices in I and uses them to look up a starting index, G, from
+// the start indices tensor, and expands G into the input space according to
+// start_index_map.
+class OutputBatchIndexToInputIndex {
public:
// The constructor does some setup work that is amortized across all
// iterations.
- explicit OutputGatherIndexToInputIndex(
+ explicit OutputBatchIndexToInputIndex(
const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
- const Shape& output_shape, const Literal* gather_indices)
- : dim_numbers_(*dim_numbers), gather_indices_(*gather_indices) {
+ const Shape& output_shape, const Literal* start_indices)
+ : dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
- output_dim_is_gather_dims_.push_back(
- !c_binary_search(dim_numbers_.output_window_dims(), i));
+ output_dim_is_batch_dims_.push_back(
+ !c_binary_search(dim_numbers_.offset_dims(), i));
}
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
int64 index_of_input_dim_in_index_vector =
- std::distance(dim_numbers_.gather_dims_to_operand_dims().begin(),
- c_find(dim_numbers_.gather_dims_to_operand_dims(), i));
+ std::distance(dim_numbers_.start_index_map().begin(),
+ c_find(dim_numbers_.start_index_map(), i));
if (index_of_input_dim_in_index_vector ==
- dim_numbers_.gather_dims_to_operand_dims_size()) {
+ dim_numbers_.start_index_map_size()) {
input_dim_value_to_index_vector_.push_back(-1);
} else {
input_dim_value_to_index_vector_.push_back(
@@ -630,14 +626,14 @@ class OutputGatherIndexToInputIndex {
}
}
- index_vector_index_.resize(gather_indices_.shape().dimensions_size());
+ index_vector_index_.resize(start_indices_.shape().dimensions_size());
input_index_.resize(input_shape.dimensions_size());
int64 index_vector_size =
- gather_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
+ start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
index_vector_.resize(index_vector_size);
}
- // Returns the contribution of gather_indices to the input index corresponding
+ // Returns the contribution of start_indices to the input index corresponding
// to output_index. See gather_inner_loop_body.
//
// This is conceptually a stateless transformation from output_index to the
@@ -659,7 +655,7 @@ class OutputGatherIndexToInputIndex {
}
private:
- // Propagates the gather index dimensions from the output index into
+ // Propagates the batch dimensions from the output index into
// index_vector_index_ by mutating index_vector_index_ in place. Does not
// update the dim_numbers.index_vector_dim() dimension -- that's the dimension
// we iterate over in FetchIndexVector.
@@ -667,7 +663,7 @@ class OutputGatherIndexToInputIndex {
ArraySlice<int64> output_index) {
int64 index_vector_index_i = 0;
for (int64 i = 0, e = output_index.size(); i < e; i++) {
- if (!output_dim_is_gather_dims_[i]) {
+ if (!output_dim_is_batch_dims_[i]) {
continue;
}
@@ -679,14 +675,14 @@ class OutputGatherIndexToInputIndex {
}
}
- // Populates index_vector_ by iterating over gather_indices_ according to
+ // Populates index_vector_ by iterating over start_indices_ according to
// index_vector_index_.
Status FetchIndexVector() {
int64 index_vector_dim = dim_numbers_.index_vector_dim();
for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
index_vector_index_[index_vector_dim] = i;
- TF_ASSIGN_OR_RETURN(index_vector_[i], gather_indices_.GetIntegralAsS64(
- index_vector_index_));
+ TF_ASSIGN_OR_RETURN(index_vector_[i],
+ start_indices_.GetIntegralAsS64(index_vector_index_));
}
return Status::OK();
}
@@ -708,15 +704,15 @@ class OutputGatherIndexToInputIndex {
// PropagateIndexVectorToInputIndex.
std::vector<int64> input_dim_value_to_index_vector_;
- // output_dim_is_gather_dims_[i] is true iff the output index i is a gather
+ // output_dim_is_batch_dims_[i] is true iff the output index i is a gather
// dimension.
- std::vector<bool> output_dim_is_gather_dims_;
+ std::vector<bool> output_dim_is_batch_dims_;
- // The buffer into which we construct an index into gather_indices_ to fetch
+ // The buffer into which we construct an index into start_indices_ to fetch
// the index vector.
std::vector<int64> index_vector_index_;
- // The index vector fetched from gather_indices_.
+ // The index vector fetched from start_indices_.
std::vector<int64> index_vector_;
// The result computed by this functor. operator() returns an ArraySlice into
@@ -724,24 +720,23 @@ class OutputGatherIndexToInputIndex {
std::vector<int64> input_index_;
const GatherDimensionNumbers& dim_numbers_;
- const Literal& gather_indices_;
+ const Literal& start_indices_;
};
-// This functor computes the contribution of the window indices in an output
+// This functor computes the contribution of the offset indices in an output
// index to an input index. That is, given an output index I it picks out the
-// output window indices in I and expands it into a window index into the input
-// shape.
-class OutputWindowIndexToInputIndex {
+// output offset indices in I and expands it into an index into the input shape.
+class OutputOffsetIndexToInputIndex {
public:
// The constructor does some setup work that is amortized across all
// iterations.
- explicit OutputWindowIndexToInputIndex(
+ explicit OutputOffsetIndexToInputIndex(
const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
const Shape& output_shape) {
std::vector<int64> window_index_to_output_index;
int64 output_index_count = 0;
for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.output_window_dims(), i)) {
+ if (c_binary_search(dim_numbers.offset_dims(), i)) {
window_index_to_output_index.push_back(output_index_count++);
} else {
output_index_count++;
@@ -750,7 +745,7 @@ class OutputWindowIndexToInputIndex {
int64 window_dim_count = 0;
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
+ if (c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
input_dim_value_to_output_index_.push_back(-1);
} else {
input_dim_value_to_output_index_.push_back(
@@ -808,20 +803,20 @@ class OutputWindowIndexToInputIndex {
// Reshapes the gather indices input to have a trailing degenerate `1` dimension
// if necessary. Hands over the ownership of the newly created literal (if
-// there is one) to `reshaped_gather_indices`.
+// there is one) to `reshaped_start_indices`.
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
- int64 index_vector_dim, const Literal& gather_indices,
- std::unique_ptr<Literal>* reshaped_gather_indices) {
- if (gather_indices.shape().dimensions_size() != index_vector_dim) {
- return std::cref(gather_indices);
+ int64 index_vector_dim, const Literal& start_indices,
+ std::unique_ptr<Literal>* reshaped_start_indices) {
+ if (start_indices.shape().dimensions_size() != index_vector_dim) {
+ return std::cref(start_indices);
}
- std::vector<int64> new_shape(gather_indices.shape().dimensions().begin(),
- gather_indices.shape().dimensions().end());
+ std::vector<int64> new_shape(start_indices.shape().dimensions().begin(),
+ start_indices.shape().dimensions().end());
new_shape.push_back(1);
- TF_ASSIGN_OR_RETURN(*reshaped_gather_indices,
- gather_indices.Reshape(new_shape));
- return std::cref(**reshaped_gather_indices);
+ TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
+ start_indices.Reshape(new_shape));
+ return std::cref(**reshaped_start_indices);
}
Status HloEvaluator::HandleGather(HloInstruction* gather) {
@@ -830,34 +825,33 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
const GatherDimensionNumbers& dim_numbers =
gather->gather_dimension_numbers();
const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
- std::unique_ptr<Literal> reshaped_gather_indices;
+ std::unique_ptr<Literal> reshaped_start_indices;
TF_ASSIGN_OR_RETURN(
- const Literal& gather_indices,
+ const Literal& start_indices,
ReshapedGatherIndices(dim_numbers.index_vector_dim(),
GetEvaluatedLiteralFor(gather->operand(1)),
- &reshaped_gather_indices));
+ &reshaped_start_indices));
// We iterate over the gather dimensions in the output shape in an outer loop
// nest, and iterate over the window dimensions in the output shape in an
// inner loop nest.
- ShapeUtil::IndexIterationSpace gather_indices_iteration_space =
- IterationSpaceForOutputGatherIndices(shape, dim_numbers);
- ShapeUtil::IndexIterationSpace window_indices_iteration_space =
- IterationSpaceForOutputWindowIndices(
- shape.dimensions_size(), gather->gather_window_bounds(), dim_numbers);
+ ShapeUtil::IndexIterationSpace start_indices_iteration_space =
+ IterationSpaceForOutputBatchIndices(shape, dim_numbers);
+ ShapeUtil::IndexIterationSpace offset_indices_iteration_space =
+ IterationSpaceForOutputOffsetIndices(
+ shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers);
// Scratch buffers that hold an index in the output shape and the
// corresponding index in the input shape.
std::vector<int64> input_index(operand.shape().dimensions_size());
std::vector<int64> output_index(gather->shape().dimensions_size());
- std::vector<int64> input_gather_index_clamped(
- operand.shape().dimensions_size());
+ std::vector<int64> input_index_clamped(operand.shape().dimensions_size());
- OutputGatherIndexToInputIndex output_gather_index_to_input_index(
+ OutputBatchIndexToInputIndex output_batch_index_to_input_index(
&gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
- /*output_shape=*/shape, &gather_indices);
- OutputWindowIndexToInputIndex output_window_index_to_input_index(
+ /*output_shape=*/shape, &start_indices);
+ OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
/*output_shape=*/shape);
@@ -869,29 +863,29 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
ArraySlice<int64> input_window_index,
- output_window_index_to_input_index(output_window_index));
+ output_offset_index_to_input_index(output_window_index));
for (int i = 0, e = output_index.size(); i < e; i++) {
output_index[i] = output_gather_index[i] + output_window_index[i];
DCHECK_LT(output_index[i], shape.dimensions(i));
}
for (int i = 0, e = input_gather_index.size(); i < e; i++) {
int64 output_dim =
- output_window_index_to_input_index.input_dim_value_to_output_index(i);
+ output_offset_index_to_input_index.input_dim_value_to_output_index(i);
// If 'output_dim' is -1, it means 'i' is an elided window dim. This means
// we set the iteration index to 0, so for the purpose of the following
// calculations we can consider the output dimension size to be 1.
int64 output_dim_size =
output_dim == -1 ? 1 : shape.dimensions(output_dim);
// Clamp the gather index so that the gather region fits in the operand.
- // input_gather_index_clamped[i] = clamp(input_gather_index[i], 0,
+ // input_index_clamped[i] = clamp(input_gather_index[i], 0,
// operand_shape.dimensions(i) -
// output_dim_size);
- input_gather_index_clamped[i] =
+ input_index_clamped[i] =
std::min(operand_shape.dimensions(i) - output_dim_size,
std::max(0LL, input_gather_index[i]));
}
for (int i = 0, e = input_index.size(); i < e; i++) {
- input_index[i] = input_gather_index_clamped[i] + input_window_index[i];
+ input_index[i] = input_index_clamped[i] + input_window_index[i];
DCHECK_GE(input_index[i], 0);
DCHECK_LT(input_index[i], operand_shape.dimensions(i));
}
@@ -902,18 +896,17 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
auto gather_outer_loop_body =
[&](ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
- TF_ASSIGN_OR_RETURN(
- ArraySlice<int64> input_gather_index,
- output_gather_index_to_input_index(output_gather_index));
+ TF_ASSIGN_OR_RETURN(ArraySlice<int64> input_gather_index,
+ output_batch_index_to_input_index(output_gather_index));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
- shape, window_indices_iteration_space,
+ shape, offset_indices_iteration_space,
std::bind(gather_inner_loop_body, std::placeholders::_1,
input_gather_index, output_gather_index)));
return true;
};
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
- shape, gather_indices_iteration_space, gather_outer_loop_body));
+ shape, start_indices_iteration_space, gather_outer_loop_body));
evaluated_[gather] = std::move(result);
return Status::OK();
}
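In HandleGather above, every output index is split into a batch part, which OutputBatchIndexToInputIndex resolves through start_indices and start_index_map and which is then clamped so the slice stays inside the operand, and an offset part, which OutputOffsetIndexToInputIndex maps onto the non-collapsed operand dimensions. A minimal plain-C++ sketch of that decomposition, specialized to the TensorFlowGatherV1 case from the tests below (hard-coded shapes and illustrative names, not XLA code):

#include <algorithm>
#include <cstdio>

// operand s32[3,3], start_indices s32[2], offset_dims={1},
// collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
// slice_sizes={1,3}  =>  output s32[2,3].
int main() {
  int operand[3][3] = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
  int start_indices[2] = {0, 2};
  int output[2][3];

  for (int batch = 0; batch < 2; batch++) {  // batch dim of the output
    // Batch contribution: look up the start index and map it to operand dim 0
    // (start_index_map={0}), clamping so the {1,3} slice stays in bounds.
    int start_row = std::min(3 - 1, std::max(0, start_indices[batch]));
    for (int offset = 0; offset < 3; offset++) {  // offset (slice) dim of the output
      // Offset contribution lands in operand dim 1; operand dim 0 is collapsed.
      output[batch][offset] = operand[start_row][offset];
    }
  }

  // Prints {{1,2,3},{7,8,9}}, the expected result in the evaluator test.
  for (int b = 0; b < 2; b++)
    printf("{%d,%d,%d}\n", output[b][0], output[b][1], output[b][2]);
  return 0;
}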
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 1cef3549e0..1394be68e4 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -1826,21 +1826,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 3}
+ slice_sizes={1, 3}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
*LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
@@ -1851,21 +1850,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
*LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
@@ -1876,22 +1874,22 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
EXPECT_TRUE(LiteralTestUtil::Equal(
*LiteralUtil::CreateR3<int32>(
{{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
@@ -1902,11 +1900,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
ParseAndVerifyModule(hlo_text);
@@ -1914,11 +1912,11 @@ ENTRY main {
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest,
@@ -1930,11 +1928,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
ParseAndVerifyModule(hlo_text);
@@ -1942,11 +1940,11 @@ ENTRY main {
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
@@ -1957,21 +1955,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{5}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
@@ -1982,21 +1979,21 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
@@ -2007,20 +2004,19 @@ ENTRY main {
operand = s32[3,0] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,0] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 0}
+ slice_sizes={1, 0}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{}, {}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
@@ -2031,21 +2027,21 @@ ENTRY main {
operand = s32[3] parameter(0)
indices = s32[2,2,1] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1}
+ slice_sizes={1}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
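For reference, a worked reading of how the renamed attributes produce the BatchDynamicSlice result above (operand s32[3,3], start_indices s32[2,2] = {{2,1},{1,1}}):

  index_vector_dim=0 marks dimension 0 of start_indices as the index vector, leaving one batch dimension of size 2.
  offset_dims={1,2} with collapsed_slice_dims={} keeps both slice dimensions; their bounds come from slice_sizes={1,1}.
  Result shape s32[2,1,1]: batch 0 starts at (2,1) -> operand[2][1] = 8, batch 1 starts at (1,1) -> operand[1][1] = 5.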
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 4aaef1941b..57e75cf931 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -392,13 +392,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< "Gather instruction should have GatherDimensionNumbers set.";
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers =
MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers());
- std::vector<int64> gather_window_bounds;
- for (int64 bound : proto.gather_window_bounds()) {
- gather_window_bounds.push_back(bound);
+ std::vector<int64> gather_slice_sizes;
+ for (int64 bound : proto.gather_slice_sizes()) {
+ gather_slice_sizes.push_back(bound);
}
- instruction =
- CreateGather(proto.shape(), operands(0), operands(1),
- *gather_dimension_numbers, gather_window_bounds);
+ instruction = CreateGather(proto.shape(), operands(0), operands(1),
+ *gather_dimension_numbers, gather_slice_sizes);
break;
}
case HloOpcode::kScatter: {
@@ -1078,11 +1077,11 @@ bool HloInstruction::HasSideEffect() const {
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
- const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
+ const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
- return MakeUnique<HloGatherInstruction>(shape, operand, gather_indices,
- gather_dim_numbers, window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ return MakeUnique<HloGatherInstruction>(shape, operand, start_indices,
+ gather_dim_numbers, slice_sizes);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
@@ -3226,9 +3225,8 @@ const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
}
-tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_window_bounds()
- const {
- return Cast<HloGatherInstruction>(this)->gather_window_bounds();
+tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_slice_sizes() const {
+ return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
}
const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index b3eee90099..8d8f149ee3 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -667,9 +667,9 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateGather(
const Shape& shape, HloInstruction* operand,
- HloInstruction* gather_indices,
+ HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
static std::unique_ptr<HloInstruction> CreateScatter(
const Shape& shape, HloInstruction* operand,
@@ -1489,8 +1489,8 @@ class HloInstruction {
// Delegates to HloGatherInstruction::gather_dimension_numbers.
const GatherDimensionNumbers& gather_dimension_numbers() const;
- // Delegates to HloGatherInstruction::gather_window_bounds.
- tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const;
+ // Delegates to HloGatherInstruction::gather_slice_sizes.
+ tensorflow::gtl::ArraySlice<int64> gather_slice_sizes() const;
// Delegates to HloScatterInstruction::scatter_dimension_numbers().
const ScatterDimensionNumbers& scatter_dimension_numbers() const;
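A minimal caller-side sketch of the renamed gather delegates; the helper name is hypothetical, and the size relation it returns is the one shape inference validates later in this change:

#include "tensorflow/compiler/xla/service/hlo_instruction.h"

namespace xla {
// Illustrative only: the old gather_window_bounds() accessor is now
// gather_slice_sizes(), and the dimension numbers carry the renamed fields.
bool GatherSliceSizesAreConsistent(const HloInstruction* instr) {
  const GatherDimensionNumbers& dnums = instr->gather_dimension_numbers();
  tensorflow::gtl::ArraySlice<int64> sizes = instr->gather_slice_sizes();
  // One slice size per operand dimension; each operand dimension is either an
  // offset dimension of the output or an explicitly collapsed slice dimension.
  return static_cast<int64>(sizes.size()) ==
         dnums.offset_dims_size() + dnums.collapsed_slice_dims_size();
}
}  // namespace xla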
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 8a694dde80..504b13043f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1355,7 +1355,7 @@ TEST_F(HloInstructionTest, Stringification) {
TEST_F(HloInstructionTest, StringifyGather_0) {
Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
- Shape gather_indices_tensor_shape =
+ Shape start_indices_tensor_shape =
ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
Shape gather_result_shape =
ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
@@ -1363,19 +1363,18 @@ TEST_F(HloInstructionTest, StringifyGather_0) {
HloComputation::Builder builder("Gather");
HloInstruction* input = builder.AddInstruction(
HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
- HloInstruction* gather_indices =
+ HloInstruction* start_indices =
builder.AddInstruction(HloInstruction::CreateParameter(
- 1, gather_indices_tensor_shape, "gather_indices"));
-
- HloInstruction* gather_instruction =
- builder.AddInstruction(HloInstruction::CreateGather(
- gather_result_shape, input, gather_indices,
- HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
- /*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ 1, start_indices_tensor_shape, "start_indices"));
+
+ HloInstruction* gather_instruction = builder.AddInstruction(
+ HloInstruction::CreateGather(gather_result_shape, input, start_indices,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -1383,15 +1382,15 @@ TEST_F(HloInstructionTest, StringifyGather_0) {
EXPECT_EQ(gather_instruction->ToString(),
"%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
"gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
- "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), "
- "output_window_dims={4,5,6,7,8}, elided_window_dims={}, "
- "gather_dims_to_operand_dims={0,1,2,3,4}, "
- "index_vector_dim=4, window_bounds={30,29,28,27,26}");
+ "s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), "
+ "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
+ "start_index_map={0,1,2,3,4}, "
+ "index_vector_dim=4, slice_sizes={30,29,28,27,26}");
}
TEST_F(HloInstructionTest, StringifyGather_1) {
Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
- Shape gather_indices_tensor_shape =
+ Shape start_indices_tensor_shape =
ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
Shape gather_result_shape =
ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
@@ -1399,19 +1398,18 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
HloComputation::Builder builder("Gather");
HloInstruction* input = builder.AddInstruction(
HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
- HloInstruction* gather_indices =
+ HloInstruction* start_indices =
builder.AddInstruction(HloInstruction::CreateParameter(
- 1, gather_indices_tensor_shape, "gather_indices"));
-
- HloInstruction* gather_instruction =
- builder.AddInstruction(HloInstruction::CreateGather(
- gather_result_shape, input, gather_indices,
- HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
- /*index_vector_dim=*/2),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ 1, start_indices_tensor_shape, "start_indices"));
+
+ HloInstruction* gather_instruction = builder.AddInstruction(
+ HloInstruction::CreateGather(gather_result_shape, input, start_indices,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/2),
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -1419,10 +1417,10 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
EXPECT_EQ(gather_instruction->ToString(),
"%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
"gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
- "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), "
- "output_window_dims={4,5,6,7,8}, elided_window_dims={}, "
- "gather_dims_to_operand_dims={0,1,2,3,4}, "
- "index_vector_dim=2, window_bounds={30,29,28,27,26}");
+ "s64[10,9,5,7,6]{4,3,2,1,0} %start_indices), "
+ "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
+ "start_index_map={0,1,2,3,4}, "
+ "index_vector_dim=2, slice_sizes={30,29,28,27,26}");
}
TEST_F(HloInstructionTest, StringifyScatter) {
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 233cdda7b0..4fdf4360e6 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1965,51 +1965,50 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
}
HloGatherInstruction::HloGatherInstruction(
- const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
+ const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds)
+ tensorflow::gtl::ArraySlice<int64> slice_sizes)
: HloInstruction(HloOpcode::kGather, shape) {
AppendOperand(operand);
- AppendOperand(gather_indices);
+ AppendOperand(start_indices);
gather_dimension_numbers_ =
MakeUnique<GatherDimensionNumbers>(gather_dim_numbers);
- c_copy(window_bounds, std::back_inserter(gather_window_bounds_));
+ c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
}
string HloGatherInstruction::GatherDimensionNumbersToString() const {
CHECK(gather_dimension_numbers_ != nullptr);
- string output_window_dims =
- StrCat("output_window_dims={",
- Join(gather_dimension_numbers_->output_window_dims(), ","), "}");
- string elided_window_dims =
- StrCat("elided_window_dims={",
- Join(gather_dimension_numbers_->elided_window_dims(), ","), "}");
- string gather_dims_to_operand_dims = StrCat(
- "gather_dims_to_operand_dims={",
- Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}");
+ string offset_dims =
+ StrCat("offset_dims={",
+ Join(gather_dimension_numbers_->offset_dims(), ","), "}");
+ string collapsed_slice_dims =
+ StrCat("collapsed_slice_dims={",
+ Join(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}");
+ string start_index_map =
+ StrCat("start_index_map={",
+ Join(gather_dimension_numbers_->start_index_map(), ","), "}");
string index_vector_dim = StrCat(
"index_vector_dim=", gather_dimension_numbers_->index_vector_dim());
return Join<std::initializer_list<string>>(
- {output_window_dims, elided_window_dims, gather_dims_to_operand_dims,
- index_vector_dim},
+ {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
", ");
}
/* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> output_window_dims,
- tensorflow::gtl::ArraySlice<int64> elided_window_dims,
- tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
+ tensorflow::gtl::ArraySlice<int64> offset_dims,
+ tensorflow::gtl::ArraySlice<int64> collapsed_slice_dims,
+ tensorflow::gtl::ArraySlice<int64> start_index_map,
int64 index_vector_dim) {
GatherDimensionNumbers gather_dim_numbers;
- for (int64 output_window_dim : output_window_dims) {
- gather_dim_numbers.add_output_window_dims(output_window_dim);
+ for (int64 output_window_dim : offset_dims) {
+ gather_dim_numbers.add_offset_dims(output_window_dim);
}
- for (int64 elided_window_dim : elided_window_dims) {
- gather_dim_numbers.add_elided_window_dims(elided_window_dim);
+ for (int64 elided_window_dim : collapsed_slice_dims) {
+ gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim);
}
- for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) {
- gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim);
+ for (int64 gather_dim_to_input_dim : start_index_map) {
+ gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
}
gather_dim_numbers.set_index_vector_dim(index_vector_dim);
@@ -2019,8 +2018,8 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const {
HloInstructionProto HloGatherInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
*proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
- for (int64 bound : gather_window_bounds()) {
- proto.add_gather_window_bounds(bound);
+ for (int64 bound : gather_slice_sizes()) {
+ proto.add_gather_slice_sizes(bound);
}
return proto;
}
@@ -2028,7 +2027,7 @@ HloInstructionProto HloGatherInstruction::ToProto() const {
std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
return {GatherDimensionNumbersToString(),
- StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")};
+ StrCat("slice_sizes={", Join(gather_slice_sizes(), ","), "}")};
}
bool HloGatherInstruction::IdenticalSlowPath(
@@ -2039,7 +2038,7 @@ bool HloGatherInstruction::IdenticalSlowPath(
return protobuf_util::ProtobufEquals(
gather_dimension_numbers(),
casted_other.gather_dimension_numbers()) &&
- gather_window_bounds() == casted_other.gather_window_bounds();
+ gather_slice_sizes() == casted_other.gather_slice_sizes();
}
std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
@@ -2049,7 +2048,7 @@ std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
CHECK_EQ(new_operands.size(), 2);
return MakeUnique<HloGatherInstruction>(
shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
- gather_window_bounds());
+ gather_slice_sizes());
}
HloScatterInstruction::HloScatterInstruction(
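MakeGatherDimNumbers is a thin convenience over the renamed proto fields; a minimal sketch of the equivalence (the function name here is hypothetical):

#include "tensorflow/compiler/xla/service/hlo_instructions.h"

namespace xla {
// Hypothetical example: both values describe the same gather dimension numbers.
GatherDimensionNumbers ExampleGatherDimNumbers() {
  GatherDimensionNumbers via_factory =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{1}, /*collapsed_slice_dims=*/{0},
          /*start_index_map=*/{0}, /*index_vector_dim=*/1);
  GatherDimensionNumbers manual;
  manual.add_offset_dims(1);
  manual.add_collapsed_slice_dims(0);
  manual.add_start_index_map(0);
  manual.set_index_vector_dim(1);
  // via_factory and manual now compare equal under ProtobufEquals.
  return via_factory;
}
}  // namespace xla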
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 546949bc72..803dbeabeb 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1212,15 +1212,15 @@ class HloGatherInstruction : public HloInstruction {
public:
explicit HloGatherInstruction(
const Shape& shape, HloInstruction* operand,
- HloInstruction* gather_indices,
+ HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
const GatherDimensionNumbers& gather_dimension_numbers() const {
CHECK(gather_dimension_numbers_ != nullptr);
return *gather_dimension_numbers_;
}
- tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
- return gather_window_bounds_;
+ tensorflow::gtl::ArraySlice<int64> gather_slice_sizes() const {
+ return gather_slice_sizes_;
}
// Returns the dump string of the gather dimension numbers.
string GatherDimensionNumbersToString() const;
@@ -1229,9 +1229,9 @@ class HloGatherInstruction : public HloInstruction {
// Creates an instance of GatherDimensionNumbers.
static GatherDimensionNumbers MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> output_window_dims,
- tensorflow::gtl::ArraySlice<int64> elided_window_dims,
- tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
+ tensorflow::gtl::ArraySlice<int64> offset_dims,
+ tensorflow::gtl::ArraySlice<int64> collapsed_slice_dims,
+ tensorflow::gtl::ArraySlice<int64> start_index_map,
int64 index_vector_dim);
private:
@@ -1247,7 +1247,7 @@ class HloGatherInstruction : public HloInstruction {
HloCloneContext* context) const override;
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
- std::vector<int64> gather_window_bounds_;
+ std::vector<int64> gather_slice_sizes_;
};
class HloScatterInstruction : public HloInstruction {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index eb48337cd7..ab57a8b07f 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1233,22 +1233,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kGather: {
- optional<std::vector<tensorflow::int64>> output_window_dims;
- attrs["output_window_dims"] = {
- /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims};
- optional<std::vector<tensorflow::int64>> elided_window_dims;
- attrs["elided_window_dims"] = {
- /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims};
- optional<std::vector<tensorflow::int64>> gather_dims_to_operand_dims;
- attrs["gather_dims_to_operand_dims"] = {/*required=*/true,
- AttrTy::kBracedInt64List,
- &gather_dims_to_operand_dims};
+ optional<std::vector<tensorflow::int64>> offset_dims;
+ attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &offset_dims};
+ optional<std::vector<tensorflow::int64>> collapsed_slice_dims;
+ attrs["collapsed_slice_dims"] = {
+ /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
+ optional<std::vector<tensorflow::int64>> start_index_map;
+ attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &start_index_map};
optional<tensorflow::int64> index_vector_dim;
attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
&index_vector_dim};
- optional<std::vector<tensorflow::int64>> window_bounds;
- attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List,
- &window_bounds};
+ optional<std::vector<tensorflow::int64>> slice_sizes;
+ attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &slice_sizes};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
@@ -1257,14 +1256,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
GatherDimensionNumbers dim_numbers =
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/*output_window_dims,
- /*elided_window_dims=*/*elided_window_dims,
- /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims,
+ /*offset_dims=*/*offset_dims,
+ /*collapsed_slice_dims=*/*collapsed_slice_dims,
+ /*start_index_map=*/*start_index_map,
/*index_vector_dim=*/*index_vector_dim);
instruction = builder->AddInstruction(HloInstruction::CreateGather(
- shape, /*operand=*/operands[0], /*gather_indices=*/operands[1],
- dim_numbers, *window_bounds));
+ shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
+ dim_numbers, *slice_sizes));
break;
}
case HloOpcode::kScatter: {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 6fa3c63d83..0d7919346b 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -752,10 +752,10 @@ ENTRY %sparse_f32_r1 () -> f32[9] {
"gather",
R"(HloModule StringifyGather
-ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
+ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
%input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
- %gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
- ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26}
+ %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+ ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
}
)"
@@ -1030,8 +1030,8 @@ R"(HloModule gather
ENTRY Gather {
input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
- gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
- ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26}
+ start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+ ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
}
)"
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 949a4d1110..ac1a663633 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -572,7 +572,7 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) {
gather,
ShapeInference::InferGatherShape(
gather->operand(0)->shape(), gather->operand(1)->shape(),
- gather->gather_dimension_numbers(), gather->gather_window_bounds()));
+ gather->gather_dimension_numbers(), gather->gather_slice_sizes()));
}
Status ShapeVerifier::HandleScatter(HloInstruction* scatter) {
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 3531b7223f..8d17c03afc 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -153,7 +153,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
TF_ASSIGN_OR_RETURN(
computed_array,
ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(),
- instr->gather_window_bounds(),
+ instr->gather_slice_sizes(),
FindOrDie(cache_, instr->operand(0)),
FindOrDie(cache_, instr->operand(1))));
} else if (instr->opcode() == HloOpcode::kReshape) {
@@ -251,24 +251,23 @@ StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
const Shape& shape, const GatherDimensionNumbers& dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes, Array* source,
Array* indices) {
if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) {
VLOG(3) << "ComputeArrayForGather: indices are not scalar";
return nullptr;
}
- CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1);
+ CHECK_EQ(dim_numbers.start_index_map_size(), 1);
- // We can also handle dim_numbers.elided_window_dims_size() == 0 here, should
- // it become relevant.
+ // We can also handle dim_numbers.collapsed_slice_dims_size() == 0 here,
+ // should it become relevant.
- if (dim_numbers.elided_window_dims_size() != 1 ||
- dim_numbers.elided_window_dims(0) !=
- dim_numbers.gather_dims_to_operand_dims(0)) {
+ if (dim_numbers.collapsed_slice_dims_size() != 1 ||
+ dim_numbers.collapsed_slice_dims(0) != dim_numbers.start_index_map(0)) {
VLOG(3) << "ComputeArrayForGather: gather operations must elide "
- "gather_dims_to_operand_dims[0] and "
- "gather_dims_to_operand_dims[0] only";
+ "start_index_map[0] and "
+ "start_index_map[0] only";
return nullptr;
}
@@ -277,21 +276,21 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
// arrays from an array of size [7,4,6]. We check that condition down below:
for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) {
- if (i != dim_numbers.elided_window_dims(0) &&
- source->shape().dimensions(i) != window_bounds[i]) {
- VLOG(3) << "ComputeArrayForGather: window_bounds[" << i
+ if (i != dim_numbers.collapsed_slice_dims(0) &&
+ source->shape().dimensions(i) != slice_sizes[i]) {
+ VLOG(3) << "ComputeArrayForGather: slice_sizes[" << i
<< "] != source->shape().dimensions(" << i << ") -- "
- << source->shape().dimensions(i) << " vs. " << window_bounds[i]
- << " with dim_numbers.elided_window_dims(0) = "
- << dim_numbers.elided_window_dims(0);
+ << source->shape().dimensions(i) << " vs. " << slice_sizes[i]
+ << " with dim_numbers.collapsed_slice_dims(0) = "
+ << dim_numbers.collapsed_slice_dims(0);
return nullptr;
}
}
- int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0);
+ int64 source_dim = dim_numbers.start_index_map(0);
std::vector<int64> output_dims;
for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
- if (!c_binary_search(dim_numbers.output_window_dims(), i)) {
+ if (!c_binary_search(dim_numbers.offset_dims(), i)) {
output_dims.push_back(i);
}
}
@@ -735,11 +734,11 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
// operand = s32[3,5,2] constant({...})
// indices = s32[7] parameter(0)
// gather = s32[3,2,7] gather(operand, indices),
- // output_window_dims={0,1},
- // elided_window_dims={1},
- // gather_dims_to_operand_dims={1},
+ // offset_dims={0,1},
+ // collapsed_slice_dims={1},
+ // start_index_map={1},
// index_vector_dim=1,
- // window_bounds={3,1,2}
+ // slice_sizes={3,1,2}
// reshape = s32[6,7] reshape(gather)
//
// In this case the gather maps to:
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index e923dc39f7..675eb31d26 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -265,7 +265,7 @@ class IndexedArrayAnalysis {
StatusOr<Array*> ComputeArrayForGather(
const Shape& shape, const GatherDimensionNumbers& dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes, Array* source,
Array* indices);
StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index 5f4b42799b..97052edf7d 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -82,11 +82,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
}
)";
@@ -102,11 +102,11 @@ ENTRY main {
operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}})
indices = s32[5] parameter(0)
ROOT gather = s32[5,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
}
)";
@@ -122,11 +122,11 @@ ENTRY main {
operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}})
indices = s32[5,2] parameter(0)
ROOT gather = s32[5] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
@@ -141,11 +141,11 @@ ENTRY main {
operand = s32[3,3,1] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,2},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0,2},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3,1}
+ slice_sizes={1,3,1}
}
)";
@@ -160,11 +160,11 @@ ENTRY main {
operand = s32[3,3,1] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,2,3] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={2},
- gather_dims_to_operand_dims={0},
+ offset_dims={1,2},
+ collapsed_slice_dims={2},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={2,3,1}
+ slice_sizes={2,3,1}
}
)";
@@ -179,11 +179,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,2}
+ slice_sizes={1,2}
}
)";
@@ -199,17 +199,17 @@ ENTRY main {
indices_a = s32[5] parameter(0)
indices_b = s32[2] parameter(1)
gather_a = s32[5,3] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
ROOT gather_b = s32[2,3] gather(gather_a, indices_b),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
}
)";
@@ -228,17 +228,17 @@ ENTRY main {
indices_a = s32[5,7] parameter(1)
indices_b = s32[2] parameter(2)
gather_a = s32[5,3,7] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b),
- output_window_dims={0,1},
- elided_window_dims={2},
- gather_dims_to_operand_dims={2},
+ offset_dims={0,1},
+ collapsed_slice_dims={2},
+ start_index_map={2},
index_vector_dim=1,
- window_bounds={5,3,1}
+ slice_sizes={5,3,1}
}
)";
@@ -256,17 +256,17 @@ ENTRY main {
indices_a = s32[2] parameter(1)
indices_b = s32[5,7] parameter(2)
gather_a = s32[2,6] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,6}
+ slice_sizes={1,6}
}
)";
@@ -284,17 +284,17 @@ ENTRY main {
indices_a = s32[5,7] parameter(1)
indices_b = s32[4,8] parameter(2)
gather_a = s32[5,3,7] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b),
- output_window_dims={1,2},
- elided_window_dims={2},
- gather_dims_to_operand_dims={2},
+ offset_dims={1,2},
+ collapsed_slice_dims={2},
+ start_index_map={2},
index_vector_dim=2,
- window_bounds={5,3,1}
+ slice_sizes={5,3,1}
}
)";
@@ -312,11 +312,11 @@ ENTRY main {
operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
indices = s32[5] parameter(0)
gather = s32[5,4] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT reshape = s32[5,2,2] reshape(gather)
}
)";
@@ -333,11 +333,11 @@ ENTRY main {
operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
indices = s32[5,7] parameter(0)
gather = s32[5,4,7] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT reshape = s32[5,2,2,7] reshape(gather)
}
)";
@@ -358,11 +358,11 @@ ENTRY main {
{{1,2,3,4,5,6},{1,2,3,4,5,6}}})
indices = s32[5,7] parameter(0)
gather = s32[5,2,6,7] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1,2},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,2,6}
+ slice_sizes={1,2,6}
ROOT reshape = s32[5,3,4,7] reshape(gather)
}
)";
@@ -381,11 +381,11 @@ ENTRY main {
{1,2,3,4,5,6},{1,2,3,4,5,6}})
indices = s32[1] parameter(0)
gather = s32[1,6] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT reshape = s32[1,1,6] reshape(gather)
}
)";
@@ -408,14 +408,14 @@ ENTRY main {
operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } })
i.0 = s64[1,3]{1,0} parameter(0)
- g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), output_window_dims={2},
- elided_window_dims={0}, gather_dims_to_operand_dims={0},
- index_vector_dim=2, window_bounds={1,3}
+ g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2},
+ collapsed_slice_dims={0}, start_index_map={0},
+ index_vector_dim=2, slice_sizes={1,3}
i.1 = s64[1] parameter(1)
- g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), output_window_dims={0,2},
- elided_window_dims={1}, gather_dims_to_operand_dims={1},
- index_vector_dim=1, window_bounds={1,1,3}
+ g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), offset_dims={0,2},
+ collapsed_slice_dims={1}, start_index_map={1},
+ index_vector_dim=1, slice_sizes={1,1,3}
ROOT reshape = s32[1,3]{1,0} reshape(g.1)
}
@@ -441,11 +441,11 @@ ENTRY main {
operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}})
indices = s32[1] parameter(0)
gather = s32[1,6] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT reshape = s32[1,1,6] reshape(gather)
}
)";
@@ -469,11 +469,11 @@ ENTRY main {
{1,2,3,4,5,6},{1,2,3,4,5,6}}})
indices = s32[1] parameter(0)
gather = s32[1,1,6] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1,2},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={1,1,6}
+ slice_sizes={1,1,6}
ROOT reshape = s32[1,1,1,6] reshape(gather)
}
)";
@@ -500,11 +500,11 @@ ENTRY main {
{1,2,3,4,5,6},{1,2,3,4,5,6}})
indices = s32[1,5] parameter(0)
gather = s32[1,5,6] gather(operand, indices),
- output_window_dims={2},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={2},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT reshape = s32[1,1,5,6] reshape(gather)
}
)";
@@ -530,11 +530,11 @@ ENTRY main {
operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
indices = s32[5,6] parameter(0)
gather = s32[5,4,6] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT reshape = s32[5,2,2,2,3] reshape(gather)
}
)";
@@ -562,11 +562,11 @@ ENTRY main {
{{1,2},{3,4},{5,6},{7,8},{9,10}}})
indices = s32[7] parameter(0)
gather = s32[3,2,7] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0,1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1,2}
+ slice_sizes={3,1,2}
ROOT reshape = s32[6,7] reshape(gather)
}
)";
@@ -594,11 +594,11 @@ ENTRY main {
{{1},{2},{3},{4}}})
indices = s32[5,6] parameter(0)
gather = s32[5,4,6,1] gather(operand, indices),
- output_window_dims={1,3},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1,3},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,4,1}
+ slice_sizes={1,4,1}
ROOT reshape = s32[5,2,2,2,3,1] reshape(gather)
}
)";
@@ -623,11 +623,11 @@ ENTRY main {
operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
indices = s32[5] parameter(0)
gather = f32[5,4] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT tanh = f32[5,4] tanh(gather)
}
)";
@@ -650,11 +650,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";
@@ -678,11 +678,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT sub = s32[5,4] subtract(gather, constant_broadcasted)
}
)";
@@ -706,11 +706,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT sub = s32[5,4] subtract(constant_broadcasted, gather)
}
)";
@@ -733,11 +733,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";
@@ -760,11 +760,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";
@@ -808,11 +808,11 @@ ENTRY main {
dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
indices = s32[5] parameter(0)
dot_lhs = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
@@ -835,11 +835,11 @@ ENTRY main {
dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}})
indices = s32[5] parameter(0)
dot_lhs = s32[3,5] gather(gather_operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={0}, rhs_contracting_dims={0}
}
)";
@@ -863,11 +863,11 @@ ENTRY main {
dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
indices = s32[5] parameter(0)
dot_rhs = s32[3,5] gather(gather_operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
@@ -892,11 +892,11 @@ ENTRY main {
dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
indices = s32[5] parameter(0)
dot_rhs = s32[5,3] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1}
}
)";
@@ -921,11 +921,11 @@ ENTRY main {
dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}})
indices = s32[4] parameter(0)
dot_rhs = s32[2,3,4] gather(gather_operand, indices),
- output_window_dims={0,1},
- elided_window_dims={2},
- gather_dims_to_operand_dims={2},
+ offset_dims={0,1},
+ collapsed_slice_dims={2},
+ start_index_map={2},
index_vector_dim=1,
- window_bounds={2,3,1}
+ slice_sizes={2,3,1}
ROOT dot = s32[2,2,4] dot(dot_lhs_constant, dot_rhs),
lhs_contracting_dims={2}, rhs_contracting_dims={1},
lhs_batch_dims={0}, rhs_batch_dims={0}
@@ -952,11 +952,11 @@ ENTRY main {
dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}})
indices = s32[2] parameter(0)
dot_lhs = s32[3,2] gather(gather_operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT dot = s32[3,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index ec5743a777..cc1ec1704e 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -2492,201 +2492,196 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
static Status ValidateGatherDimensionNumbers(
const Shape& input_shape,
- tensorflow::gtl::ArraySlice<int64> gather_indices_shape,
+ tensorflow::gtl::ArraySlice<int64> start_indices_shape,
const GatherDimensionNumbers& dim_numbers) {
- if (!c_is_sorted(dim_numbers.output_window_dims())) {
+ if (!c_is_sorted(dim_numbers.offset_dims())) {
return InvalidArgument(
"Output window dimensions in gather op must be ascending; got: %s.",
- Join(dim_numbers.output_window_dims(), ", ").c_str());
+ Join(dim_numbers.offset_dims(), ", ").c_str());
}
- if (c_adjacent_find(dim_numbers.output_window_dims()) !=
- dim_numbers.output_window_dims().end()) {
+ if (c_adjacent_find(dim_numbers.offset_dims()) !=
+ dim_numbers.offset_dims().end()) {
return InvalidArgument(
"Output window dimensions in gather op must not repeat; got: %s.",
- Join(dim_numbers.output_window_dims(), ", ").c_str());
+ Join(dim_numbers.offset_dims(), ", ").c_str());
}
- const int64 output_window_dim_count = dim_numbers.output_window_dims_size();
+ const int64 output_offset_dim_count = dim_numbers.offset_dims_size();
const int64 output_shape_rank =
- output_window_dim_count + gather_indices_shape.size() - 1;
+ output_offset_dim_count + start_indices_shape.size() - 1;
- for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) {
- int64 window_index = dim_numbers.output_window_dims(i);
- if (window_index < 0 || window_index >= output_shape_rank) {
+ for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) {
+ int64 offset_dim = dim_numbers.offset_dims(i);
+ if (offset_dim < 0 || offset_dim >= output_shape_rank) {
return InvalidArgument(
- "Window index %d in gather op is out of bounds; got %lld, but should "
+ "Offset dimension %d in gather op is out of bounds; got %lld, but "
+ "should "
"have been in [0,%lld).",
- i, window_index, output_shape_rank);
+ i, offset_dim, output_shape_rank);
}
}
- if (dim_numbers.gather_dims_to_operand_dims_size() !=
- gather_indices_shape[dim_numbers.index_vector_dim()]) {
+ if (dim_numbers.start_index_map_size() !=
+ start_indices_shape[dim_numbers.index_vector_dim()]) {
return InvalidArgument(
- "Gather op has %d elements in gather_dims_to_operand_dims and the "
- "bound of dimension index_vector_dim=%lld of gather_indices is "
+ "Gather op has %d elements in start_index_map and the "
+ "bound of dimension index_vector_dim=%lld of start_indices is "
"%lld. These two numbers must be equal.",
- dim_numbers.gather_dims_to_operand_dims_size(),
- dim_numbers.index_vector_dim(),
- gather_indices_shape[dim_numbers.index_vector_dim()]);
+ dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
+ start_indices_shape[dim_numbers.index_vector_dim()]);
}
- for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) {
- int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i);
- if (gather_dim_to_input_dim < 0 ||
- gather_dim_to_input_dim >= input_shape.dimensions_size()) {
+ for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
+ int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
+ if (operand_dim_for_start_index_i < 0 ||
+ operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
return InvalidArgument(
- "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), "
- "got: %d->%lld.",
- input_shape.dimensions_size(), i, gather_dim_to_input_dim);
+ "Invalid start_index_map; domain is [0, %d), got: %d->%lld.",
+ input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
}
}
- std::vector<int64> sorted_gather_dims_to_operand_dims(
- dim_numbers.gather_dims_to_operand_dims().begin(),
- dim_numbers.gather_dims_to_operand_dims().end());
+ std::vector<int64> sorted_start_index_map(
+ dim_numbers.start_index_map().begin(),
+ dim_numbers.start_index_map().end());
- c_sort(sorted_gather_dims_to_operand_dims);
+ c_sort(sorted_start_index_map);
- if (c_adjacent_find(sorted_gather_dims_to_operand_dims) !=
- sorted_gather_dims_to_operand_dims.end()) {
+ if (c_adjacent_find(sorted_start_index_map) != sorted_start_index_map.end()) {
return InvalidArgument(
- "Repeated dimensions are not allowed in gather_dims_to_operand_dims; "
+ "Repeated dimensions are not allowed in start_index_map; "
"got: %s.",
- Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str());
+ Join(dim_numbers.start_index_map(), ", ").c_str());
}
- for (int64 elided_dim : dim_numbers.elided_window_dims()) {
- if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) {
+ for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) {
+ if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
return InvalidArgument(
- "Invalid elided_window_dims set in gather op; valid range is [0, "
+ "Invalid collapsed_slice_dims set in gather op; valid range is [0, "
"%d), got: %lld.",
- input_shape.dimensions_size(), elided_dim);
+ input_shape.dimensions_size(), collapsed_dim);
}
}
- if (!c_is_sorted(dim_numbers.elided_window_dims())) {
+ if (!c_is_sorted(dim_numbers.collapsed_slice_dims())) {
return InvalidArgument(
- "elided_window_dims in gather op must be sorted; got: %s",
- Join(dim_numbers.elided_window_dims(), ", ").c_str());
+ "collapsed_slice_dims in gather op must be sorted; got: %s",
+ Join(dim_numbers.collapsed_slice_dims(), ", ").c_str());
}
- if (c_adjacent_find(dim_numbers.elided_window_dims()) !=
- dim_numbers.elided_window_dims().end()) {
+ if (c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
+ dim_numbers.collapsed_slice_dims().end()) {
return InvalidArgument(
- "Repeated dimensions not allowed in elided_window_dims in gather op; "
+ "Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
"got: %s.",
- Join(dim_numbers.elided_window_dims(), ", ").c_str());
+ Join(dim_numbers.collapsed_slice_dims(), ", ").c_str());
}
return Status::OK();
}
/*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
- const Shape& input_shape, const Shape& gather_indices_shape,
+ const Shape& input_shape, const Shape& start_indices_shape,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
TF_RETURN_IF_ERROR(
ExpectArray(input_shape, "input tensor operand gather op"));
TF_RETURN_IF_ERROR(
- ExpectArray(gather_indices_shape, "gather indices operand of gather op"));
+ ExpectArray(start_indices_shape, "gather indices operand of gather op"));
- if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
+ if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
return InvalidArgument(
"Gather indices parameter must be an integral tensor; got %s.",
- ShapeUtil::HumanString(gather_indices_shape).c_str());
+ ShapeUtil::HumanString(start_indices_shape).c_str());
}
// We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
// index_vector_dim is rank(P). The bounds of this expanded shape is
- // stored in expanded_gather_indices_shape.
+ // stored in expanded_start_indices_shape.
- if (gather_indices_shape.dimensions_size() <
+ if (start_indices_shape.dimensions_size() <
gather_dim_numbers.index_vector_dim() ||
gather_dim_numbers.index_vector_dim() < 0) {
return InvalidArgument(
- "Gather index leaf dimension must be within [0, rank(gather_indices) + "
- "1). rank(gather_indices) is %d and gather index leaf dimension is "
+ "Gather index leaf dimension must be within [0, rank(start_indices) + "
+ "1). rank(start_indices) is %d and gather index leaf dimension is "
"%lld.",
- gather_indices_shape.dimensions_size(),
+ start_indices_shape.dimensions_size(),
gather_dim_numbers.index_vector_dim());
}
- std::vector<int64> expanded_gather_indices_shape;
- expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size());
- c_copy(gather_indices_shape.dimensions(),
- std::back_inserter(expanded_gather_indices_shape));
- if (expanded_gather_indices_shape.size() ==
+ std::vector<int64> expanded_start_indices_shape;
+ expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
+ c_copy(start_indices_shape.dimensions(),
+ std::back_inserter(expanded_start_indices_shape));
+ if (expanded_start_indices_shape.size() ==
gather_dim_numbers.index_vector_dim()) {
- expanded_gather_indices_shape.push_back(1);
+ expanded_start_indices_shape.push_back(1);
}
TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
- input_shape, expanded_gather_indices_shape, gather_dim_numbers));
+ input_shape, expanded_start_indices_shape, gather_dim_numbers));
- if (window_bounds.size() != input_shape.dimensions_size()) {
+ if (slice_sizes.size() != input_shape.dimensions_size()) {
return InvalidArgument(
- "Gather op must have one window bound for every input dimension; got: "
- "len(window_bounds)=%lu, input_shape.rank=%d.",
- window_bounds.size(), input_shape.dimensions_size());
+ "Gather op must have one slice size for every input dimension; got: "
+ "len(slice_sizes)=%lu, input_shape.rank=%d.",
+ slice_sizes.size(), input_shape.dimensions_size());
}
- if (window_bounds.size() !=
- gather_dim_numbers.output_window_dims_size() +
- gather_dim_numbers.elided_window_dims_size()) {
+ if (slice_sizes.size() !=
+ gather_dim_numbers.offset_dims_size() +
+ gather_dim_numbers.collapsed_slice_dims_size()) {
return InvalidArgument(
- "All components of the window index in a gather op must either be a "
- "output window index or explicitly elided; got len(window_bounds)=%lu, "
- "output_window_bounds=%s, elided_window_bounds=%s.",
- window_bounds.size(),
- Join(gather_dim_numbers.output_window_dims(), ",").c_str(),
- Join(gather_dim_numbers.elided_window_dims(), ",").c_str());
+ "All components of the offset index in a gather op must either be a "
+ "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
+ "output_slice_sizes=%s, collapsed_slice_dims=%s.",
+ slice_sizes.size(), Join(gather_dim_numbers.offset_dims(), ",").c_str(),
+ Join(gather_dim_numbers.collapsed_slice_dims(), ",").c_str());
}
- for (int i = 0; i < window_bounds.size(); i++) {
- int64 window_bound = window_bounds[i];
- int64 corresponding_input_bound = input_shape.dimensions(i);
- if (window_bound < 0 || window_bound > corresponding_input_bound) {
+ for (int i = 0; i < slice_sizes.size(); i++) {
+ int64 slice_size = slice_sizes[i];
+ int64 corresponding_input_size = input_shape.dimensions(i);
+ if (slice_size < 0 || slice_size > corresponding_input_size) {
return InvalidArgument(
- "Window bound at index %d in gather op is out of range, must be "
- "within "
- "[0, %lld), got %lld.",
- i, corresponding_input_bound + 1, window_bound);
+ "Slice size at index %d in gather op is out of range, must be "
+ "within [0, %lld), got %lld.",
+ i, corresponding_input_size + 1, slice_size);
}
}
- for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) {
- if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) {
+ for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) {
+ if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) {
return InvalidArgument(
- "Gather op can only elide window indices with bound 1, but bound is "
+ "Gather op can only collapse slice dims with bound 1, but bound is "
"%lld for index %lld at position %d.",
- window_bounds[gather_dim_numbers.elided_window_dims(i)],
- gather_dim_numbers.elided_window_dims(i), i);
+ slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
+ gather_dim_numbers.collapsed_slice_dims(i), i);
}
}
- int64 result_rank = gather_dim_numbers.output_window_dims_size() +
- (expanded_gather_indices_shape.size() - 1);
- int64 window_dims_seen = 0;
+ int64 result_rank = gather_dim_numbers.offset_dims_size() +
+ (expanded_start_indices_shape.size() - 1);
+ int64 offset_dims_seen = 0;
int64 gather_dims_seen = 0;
std::vector<int64> output_dim_bounds;
output_dim_bounds.reserve(result_rank);
for (int64 i = 0; i < result_rank; i++) {
int64 current_bound;
- bool is_window_index =
- c_binary_search(gather_dim_numbers.output_window_dims(), i);
+ bool is_window_index = c_binary_search(gather_dim_numbers.offset_dims(), i);
if (is_window_index) {
- while (c_binary_search(gather_dim_numbers.elided_window_dims(),
- window_dims_seen)) {
- window_dims_seen++;
+ while (c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
+ offset_dims_seen)) {
+ offset_dims_seen++;
}
- current_bound = window_bounds[window_dims_seen++];
+ current_bound = slice_sizes[offset_dims_seen++];
} else {
if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
gather_dims_seen++;
}
- current_bound = expanded_gather_indices_shape[gather_dims_seen++];
+ current_bound = expanded_start_indices_shape[gather_dims_seen++];
}
output_dim_bounds.push_back(current_bound);
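// Worked example (editorial illustration, not from this patch): tracing the
// bounds computation above on the TensorFlowGather shape inference test that
// appears later in this diff -- operand f32[64,48], start_indices s64[32],
// offset_dims={0}, collapsed_slice_dims={1}, start_index_map={1},
// index_vector_dim=1, slice_sizes={64,1}:
//
//   expanded_start_indices_shape = [32, 1]  (a trailing 1 is appended because
//                                            index_vector_dim == indices rank)
//   result_rank = 1 + (2 - 1) = 2
//   i = 0: offset dimension -> bound = slice_sizes[0] = 64
//   i = 1: batch dimension  -> bound = expanded_start_indices_shape[0] = 32
//
// which yields the inferred shape f32[64,32] expected by that test.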
@@ -2837,25 +2832,25 @@ Status ValidateScatterDimensionNumbers(
scatter_dim_numbers));
int64 inserted_dims_seen = 0;
- std::vector<int64> max_update_window_bounds;
+ std::vector<int64> max_update_slice_sizes;
for (int i = 0; i < operand_shape.dimensions_size(); ++i) {
if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() &&
scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
++inserted_dims_seen;
} else {
- max_update_window_bounds.push_back(operand_shape.dimensions(i));
+ max_update_slice_sizes.push_back(operand_shape.dimensions(i));
}
}
for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
if (updates_shape.dimensions(update_window_dim) >
- max_update_window_bounds[i]) {
+ max_update_slice_sizes[i]) {
return InvalidArgument(
"Bounds of the window dimensions of updates must not exceed the "
"bounds of the corresponding dimensions of operand. For dimension "
"%lld, updates bound is %lld, operand bound is %lld.",
update_window_dim, updates_shape.dimensions(update_window_dim),
- max_update_window_bounds[i]);
+ max_update_slice_sizes[i]);
}
}
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index bfd79a4433..4974ac9916 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -276,9 +276,9 @@ class ShapeInference {
// with the given input shape, gather indices shape and gather dimension
// numbers.
static StatusOr<Shape> InferGatherShape(
- const Shape& input_shape, const Shape& gather_indices_shape,
+ const Shape& input_shape, const Shape& start_indices_shape,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
// Helper that validates the given input shape, scatter indices shape, updates
// shape, and scatter dimension numbers that constitute a scatter operation,
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index a73fa181cd..4ed8fc6b86 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1654,11 +1654,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) {
ShapeInference::InferGatherShape(
matrix_64_48_, s64_vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/1),
- /*window_bounds=*/{64, 1}));
+ /*slice_sizes=*/{64, 1}));
EXPECT_TRUE(
ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
<< ShapeUtil::HumanString(gather_shape);
@@ -1669,11 +1669,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) {
ShapeInference::InferGatherShape(
matrix_64_48_, s64_vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{1},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
+ /*offset_dims=*/{1},
+ /*collapsed_slice_dims=*/{0},
+ /*start_index_map=*/{0},
/*index_vector_dim=*/1),
- /*window_bounds=*/{1, 48}));
+ /*slice_sizes=*/{1, 48}));
EXPECT_TRUE(
ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
<< ShapeUtil::HumanString(gather_shape);
@@ -1684,11 +1684,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) {
ShapeInference::InferGatherShape(
matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
+ /*offset_dims=*/{4},
+ /*collapsed_slice_dims=*/{0},
+ /*start_index_map=*/{0},
/*index_vector_dim=*/4),
- /*window_bounds=*/{1, 48}));
+ /*slice_sizes=*/{1, 48}));
EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
<< ShapeUtil::HumanString(gather_shape);
@@ -1700,11 +1700,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(
gather_shape,
ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26})))
@@ -1717,11 +1717,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/2),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(
gather_shape,
@@ -1735,11 +1735,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/0),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(
gather_shape,
@@ -1749,16 +1749,15 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) {
// This is equivalent to a dynamic slice.
- TF_ASSERT_OK_AND_ASSIGN(
- Shape gather_shape,
- ShapeInference::InferGatherShape(
- f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
- HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0, 1, 2, 3, 4},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
- /*index_vector_dim=*/0),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*offset_dims=*/{0, 1, 2, 3, 4},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/0),
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26})))
@@ -1772,11 +1771,11 @@ TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0, 1, 2, 3},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
+ /*offset_dims=*/{0, 1, 2, 3},
+ /*collapsed_slice_dims=*/{0},
+ /*start_index_map=*/{0},
/*index_vector_dim=*/0),
- /*window_bounds=*/{1, 30, 29, 28, 27}));
+ /*slice_sizes=*/{1, 30, 29, 28, 27}));
EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
ShapeUtil::MakeShape(F32, {30, 29, 28, 27})))
@@ -1787,11 +1786,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
tuple_shape_, s64_vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/1),
- /*window_bounds=*/{64, 1});
+ /*slice_sizes=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Expected array argument for input"))
@@ -1802,11 +1801,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, tuple_shape_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/0),
- /*window_bounds=*/{64, 1});
+ /*slice_sizes=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Expected array argument for gather indices"))
@@ -1817,11 +1816,11 @@ TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/0),
- /*window_bounds=*/{64, 1});
+ /*slice_sizes=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Gather indices parameter must be an integral tensor"))
@@ -1833,11 +1832,11 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 8, 7},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 8, 7},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
@@ -1850,11 +1849,11 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 7},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 7},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
@@ -1867,14 +1866,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 99, 100, 101},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 99, 100, 101},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Window index 2 in gather op is out of bounds"))
+ HasSubstr("Offset dimension 2 in gather op is out of bounds"))
<< statusor.status();
}
@@ -1883,14 +1882,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 9},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 9},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Window index 4 in gather op is out of bounds"))
+ HasSubstr("Offset dimension 4 in gather op is out of bounds"))
<< statusor.status();
}
@@ -1899,16 +1898,16 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{4},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{4},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
- HasSubstr("All components of the window index in a gather op must either "
- "be a output window index or explicitly elided"))
+      HasSubstr("All components of the offset index in a gather op must either "
+                "be an offset dimension or explicitly collapsed"))
<< statusor.status();
}
@@ -1917,14 +1916,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{0, 1, 2, 3, 19},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{0, 1, 2, 3, 19},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Invalid elided_window_dims set in gather op; valid "
+ HasSubstr("Invalid collapsed_slice_dims set in gather op; valid "
"range is [0, 5), got: 19"))
<< statusor.status();
}
@@ -1934,16 +1933,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{0, 1, 2, 3, 3},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{0, 1, 2, 3, 3},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
- EXPECT_THAT(
- statusor.status().error_message(),
- HasSubstr(
- "Repeated dimensions not allowed in elided_window_dims in gather op"))
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Repeated dimensions not allowed in "
+ "collapsed_slice_dims in gather op"))
<< statusor.status();
}
@@ -1952,17 +1950,16 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
- EXPECT_THAT(
- statusor.status().error_message(),
- HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and "
- "the bound of dimension index_vector_dim=4 of "
- "gather_indices is 5. These two numbers must be equal."))
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Gather op has 4 elements in start_index_map and "
+ "the bound of dimension index_vector_dim=4 of "
+ "start_indices is 5. These two numbers must be equal."))
<< statusor.status();
}
@@ -1971,16 +1968,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 7},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
- EXPECT_THAT(
- statusor.status().error_message(),
- HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is "
- "[0, 5), got: 4->7"))
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7"))
<< statusor.status();
}
@@ -1989,16 +1984,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 3},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
- HasSubstr(
- "Repeated dimensions are not allowed in gather_dims_to_operand_dims"))
+ HasSubstr("Repeated dimensions are not allowed in start_index_map"))
<< statusor.status();
}
@@ -2007,14 +2001,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{2, 1},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{2, 1},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{1, 1, 28, 27, 26});
+ /*slice_sizes=*/{1, 1, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("elided_window_dims in gather op must be sorted"))
+ HasSubstr("collapsed_slice_dims in gather op must be sorted"))
<< statusor.status();
}
@@ -2023,15 +2017,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7},
- /*elided_window_dims=*/{2},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7},
+ /*collapsed_slice_dims=*/{2},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 1, 300, 26});
+ /*slice_sizes=*/{30, 29, 1, 300, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Window bound at index 3 in gather op is out of range, "
- "must be within [0, 48), got 300"))
+ HasSubstr("Slice size at index 3 in gather op is out of range, "
+ "must be within [0, 48), got 300."))
<< statusor.status();
}
@@ -2040,16 +2034,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 26});
+ /*slice_sizes=*/{30, 29, 28, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
- HasSubstr(
- "Gather op must have one window bound for every input dimension"))
+ HasSubstr("Gather op must have one slice size for every input dimension"))
<< statusor.status();
}
@@ -2058,15 +2051,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 26, 20});
+ /*slice_sizes=*/{30, 29, 28, 26, 20});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Gather op can only elide window indices with bound 1, "
- "but bound is 29 for index 1 at position 0"))
+ HasSubstr("Gather op can only collapse slice dims with bound 1, "
+ "but bound is 29 for index 1 at position 0."))
<< statusor.status();
}
@@ -2074,16 +2067,16 @@ TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/32),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Gather index leaf dimension must be within [0, "
- "rank(gather_indices) + 1)"))
+ "rank(start_indices) + 1)"))
<< statusor.status();
}
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index b77bece85a..f866ed6519 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -30,8 +30,8 @@ using tensorflow::gtl::nullopt;
class GatherOperationTest : public HloTestBase {
protected:
void RunTest(const string& hlo_text, Literal* operand,
- Literal* gather_indices) {
- RunTest(hlo_text, {operand, gather_indices});
+ Literal* start_indices) {
+ RunTest(hlo_text, {operand, start_indices});
}
void RunTest(const string& hlo_text,
@@ -52,18 +52,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 3}
+ slice_sizes={1, 3}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) {
@@ -74,18 +73,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) {
@@ -96,18 +94,18 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) {
@@ -118,18 +116,18 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) {
@@ -140,18 +138,18 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
ROOT gather = s32[2,1,1,2] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) {
@@ -162,20 +160,20 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) {
@@ -186,20 +184,20 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, DynamicSlice) {
@@ -210,18 +208,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) {
@@ -232,18 +229,18 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, ZeroDimBounds) {
@@ -254,17 +251,16 @@ ENTRY main {
operand = s32[3,0] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,0] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 0}
+ slice_sizes={1, 0}
}
)";
std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) {
@@ -278,19 +274,19 @@ ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) {
@@ -304,19 +300,19 @@ ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = u32[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<uint32>(
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<uint32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, NegativeIndex) {
@@ -330,19 +326,19 @@ ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) {
@@ -356,19 +352,19 @@ ENTRY main {
operand = u32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
gather = u32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = u32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<uint32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, OneScalarIndex) {
@@ -379,17 +375,17 @@ ENTRY main {
operand = s32[2,3,2]{2,1,0} parameter(0)
index = s32[] parameter(1)
ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index),
- output_window_dims={0,1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0},
+ offset_dims={0,1,2},
+ collapsed_slice_dims={},
+ start_index_map={0},
index_vector_dim=0,
- window_bounds={1,3,2}
+ slice_sizes={1,3,2}
}
)";
std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1);
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR0<int32>(1);
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, ScalarResult) {
@@ -400,16 +396,16 @@ ENTRY main {
operand = s32[4]{0} parameter(0)
index = s32[] parameter(1)
ROOT gather = s32[] gather(operand, index),
- output_window_dims={},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=0,
- window_bounds={1}
+ slice_sizes={1}
}
)";
std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1);
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR0<int32>(1);
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, ZeroSizedResult) {
@@ -420,17 +416,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[0] parameter(1)
ROOT gather = s32[0,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 3}
+ slice_sizes={1, 3}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR1<int32>({});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) {
@@ -441,11 +437,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[3,2] broadcast(one), dimensions={}
ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
@@ -453,9 +449,8 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) {
@@ -466,11 +461,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,3,2] broadcast(one), dimensions={}
ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
@@ -478,9 +473,9 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) {
@@ -491,11 +486,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -503,9 +498,9 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) {
@@ -516,11 +511,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -530,9 +525,9 @@ ENTRY main {
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest,
@@ -544,11 +539,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -558,9 +553,9 @@ ENTRY main {
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) {
@@ -571,11 +566,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[1,1] broadcast(one), dimensions={}
ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
@@ -583,9 +578,8 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) {
@@ -596,11 +590,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[2,1,1] broadcast(one), dimensions={}
ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
@@ -608,9 +602,9 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
class GatherClientLibraryTest : public ClientLibraryTestBase {};
@@ -622,11 +616,11 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
// operand = s32[3,3] parameter(0)
// indices = s32[2] parameter(1)
// ROOT gather = s32[2,3] gather(operand, indices),
- // output_window_dims={1},
- // elided_window_dims={0},
- // gather_dims_to_operand_dims={0},
+ // offset_dims={1},
+ // collapsed_slice_dims={0},
+ // start_index_map={0},
// index_vector_dim=1,
- // window_bounds={1, 3}
+ // slice_sizes={1, 3}
// }
XlaBuilder builder("gather_basic");
@@ -637,9 +631,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
auto operand = Parameter(&builder, 0, operand_shape, "operand");
auto indices = Parameter(&builder, 1, indices_shape, "indices");
GatherDimensionNumbers dim_numbers;
- dim_numbers.add_output_window_dims(1);
- dim_numbers.add_elided_window_dims(0);
- dim_numbers.add_gather_dims_to_operand_dims(0);
+ dim_numbers.add_offset_dims(1);
+ dim_numbers.add_collapsed_slice_dims(0);
+ dim_numbers.add_start_index_map(0);
dim_numbers.set_index_vector_dim(1);
Gather(operand, indices, dim_numbers, {1, 3});
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 4c35e93d38..27aa94c2cb 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -424,25 +424,25 @@ message GatherDimensionNumbers {
// "Window indices" is a term for a set of indices that index into the
// interior of a dynamic-slice from the input tensor, the starting indices for
// which were computed from output_gather_dims (see the operation semantic for
- // how this is defined) and the gather_indices tensor.
+ // how this is defined) and the start_indices tensor.
//
// The window indices for a specific output index Out is computed as:
//
// i = 0
// for (k : [0, input_tensor_shape.rank))
// window_indices[k] =
- // if k in elided_window_dims
+ // if k in collapsed_slice_dims
// then 0
- // else Out[output_window_dims[i++]]
- repeated int64 output_window_dims = 1;
- repeated int64 elided_window_dims = 2;
+ // else Out[offset_dims[i++]]
+ repeated int64 offset_dims = 1;
+ repeated int64 collapsed_slice_dims = 2;
- // This is interpreted as a map from i to gather_dims_to_operand_dims[i]. It
- // transforms the gather index looked up from the gather_indices tensor into
+ // This is interpreted as a map from i to start_index_map[i]. It
+ // transforms the gather index looked up from the start_indices tensor into
// the starting index in the input space.
- repeated int64 gather_dims_to_operand_dims = 3;
+ repeated int64 start_index_map = 3;
- // The dimension in the gather_indices input that contains the starting
+ // The dimension in the start_indices input that contains the starting
// indices.
int64 index_vector_dim = 4;
}
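As an aside, the renamed fields are populated from C++ in the same way as their
predecessors. The sketch below simply mirrors the dimension numbers used by
GatherClientLibraryTest later in this diff; the include of the generated
xla_data.pb.h header is the only assumption made here.

```c++
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Dimension numbers for the s32[3,3] "TensorFlowGather" example: each entry of
// a s32[2] start-index vector selects one whole row of the operand.
xla::GatherDimensionNumbers MakeRowGatherDimNumbers() {
  xla::GatherDimensionNumbers dim_numbers;
  dim_numbers.add_offset_dims(1);           // output dim 1 ranges over the row slice
  dim_numbers.add_collapsed_slice_dims(0);  // drop the size-1 dim of each [1,3] slice
  dim_numbers.add_start_index_map(0);       // the start index addresses operand dim 0
  dim_numbers.set_index_vector_dim(1);      // == indices rank: scalar start indices
  return dim_numbers;
}
// Used together with slice_sizes={1, 3}, e.g.
//   Gather(operand, indices, MakeRowGatherDimNumbers(), {1, 3});
```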
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 16dd3c5bf3..2de30d1b3d 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -1138,7 +1138,7 @@ array with the same shape. It is allowed for `operand` to be a scalar (rank 0).
## Gather
The XLA gather operation stitches together several slices (each slice at a
-potentially different runtime offset) of an input tensor into an output tensor.
+potentially different runtime offset) of an input array.
### General Semantics
@@ -1146,151 +1146,141 @@ See also
[`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
For a more intuitive description, see the "Informal Description" section below.
-<b> `gather(operand, gather_indices, output_window_dims, elided_window_dims, window_bounds, gather_dims_to_operand_dims)` </b>
+<b> `gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)` </b>
|Arguments | Type | Semantics |
|----------------- | ----------------------- | --------------------------------|
-|`operand` | `XlaOp` | The tensor we’re gathering |
+|`operand` | `XlaOp` | The array we’re gathering |
: : : from. :
-|`gather_indices` | `XlaOp` | Tensor containing the starting |
-: : : indices of the slices we're :
-: : : stitching together into the :
-: : : output tensor. :
-|`index_vector_dim` | `int64` | The dimension in |
-: : : `gather_indices` that contains :
-: : : the starting indices. :
-|`output_window_dims` | `ArraySlice<int64>` | The set of dimensions in the |
-: : : output shape that are _window :
-: : : dimensions_ (defined below). :
-: : : Not all window dimensions may :
-: : : be present in the output shape. :
-|`elided_window_dims` | `ArraySlice<int64>` | The set of _window dimensions_ |
-: : : that are not present in the output shape. :
-: : : `window_bounds[i]` must be `1` for all `i` :
-: : : in `elided_window_dims`. :
-|`window_bounds` | `ArraySlice<int64>` | `window_bounds[i]` is the bounds |
-: : : for window dimension `i`. This includes :
-: : : both the window dimensions that are :
-: : : explicitly part of the output shape (via :
-: : : `output_window_dims`) and the window :
-: : : dimensions that are elided (via :
-: : : `elided_window_dims`). :
-|`gather_dims_to_operand_dims` | `ArraySlice<int64>` | A dimension map (the |
-: : : array is interpreted as mapping `i` to :
-: : : `gather_dims_to_operand_dims[i]`) from :
-: : : the gather indices in `gather_indices` to :
-: : : the operand index space. It has to be :
-: : : one-to-one and total. :
-
-For every index `Out` in the output tensor, we compute two things (more
-precisely described later):
-
- - An index into `gather_indices.rank` - `1` dimensions of `gather_indices`,
- which gives us a starting index of a slice, _operand slice_, in the operand
- tensor. These `gather_indices.rank` - `1` dimensions are all the dimensions
- in `gather_indices` except `index_vector_dim`.
-
- - A _window index_ that has the same rank as the operand. This index is
- composed of the values in `Out` at dimensions `output_window_dims`, embedded
- with zeroes according to `elided_window_dims`.
-
-The _window index_ is the relative index of the element in _operand slice_ that
-should be present in the output at index `Out`.
-
-The output is a tensor of rank `output_window_dims.size` + `gather_indices.rank`
-- `1`. Additionally, as a shorthand, we define `output_gather_dims` of type
-`ArraySlice<int64>` as the set of dimensions in the output shape but not in
-`output_window_dims`, in ascending order. E.g. if the output tensor has rank
-`5`, `output_window_dims` is {`2`, `4`} then `output_gather_dims` is {`0`, `1`,
-`3`}
-
-If `index_vector_dim` is equal to `gather_indices.rank` we implicitly
-consider `gather_indices` to have a trailing `1` dimension (i.e. if
-`gather_indices` was of shape `[6,7]` and `index_vector_dim` is `2` then
-we implicitly consider the shape of `gather_indices` to be `[6,7,1]`).
-
-The bounds for the output tensor along dimension `i` is computed as follows:
-
- 1. If `i` is present in `output_gather_dims` (i.e. is equal to
- `output_gather_dims[k]` for some `k`) then we pick the corresponding
- dimension bounds out of `gather_indices.shape`, skipping
- `index_vector_dim` (i.e. pick `gather_indices.shape.dims`[`k`] if `k`
- < `index_vector_dim` and `gather_indices.shape.dims`[`k`+`1`]
- otherwise).
- 2. If `i` is present in `output_window_dims` (i.e. equal to
- `output_window_dims`[`k`] for some `k`) then we pick the corresponding
- bound out of `window_bounds` after accounting for `elided_window_dims`
- (i.e. we pick `adjusted_window_bounds`[`k`] where `adjusted_window_bounds`
- is `window_bounds` with the bounds at indices `elided_window_dims`
- removed).
-
-The operand index `In` corresponding to an output index `Out` is computed as
-follows:
-
- 1. Let `G` = { `Out`[`k`] for `k` in `output_gather_dims` }. Use `G` to slice
- out vector `S` such that `S`[`i`] = `gather_indices`[Combine(`G`, `i`)]
- where Combine(A, b) inserts b at position `index_vector_dim` into A.
- Note that this is well defined even if `G` is empty -- if `G` is empty then
- `S` = `gather_indices`.
- 2. Create an index, `S`<sub>`in`</sub>, into `operand` using `S` by
- scattering `S` using the `gather_dims_to_operand_dims` map
- (`S`<sub>`in`</sub> is the starting indices for _operand slice_ mentioned
- above). More precisely:
- 1. `S`<sub>`in`</sub>[`gather_dims_to_operand_dims`[`k`]] = `S`[`k`] if `k` <
- `gather_dims_to_operand_dims.size`.
+|`start_indices` | `XlaOp` | Array containing the starting |
+: : : indices of the slices we gather.:
+|`index_vector_dim` | `int64` | The dimension in |
+: : : `start_indices` that "contains" :
+: : : the starting indices. See :
+: : : below for a detailed :
+: : : description. :
+|`offset_dims` | `ArraySlice<int64>` | The set of dimensions in the |
+: : : output shape that offset into :
+: : : an array sliced from operand. :
+|`slice_sizes` | `ArraySlice<int64>` | `slice_sizes[i]` is the bound |
+: : : for the slice on dimension `i`. :
+|`collapsed_slice_dims` | `ArraySlice<int64>` | The set of dimensions in each |
+: : : slice that are collapsed away. :
+: : : These dimensions must have :
+: : : size 1. :
+|`start_index_map` | `ArraySlice<int64>` | A map that describes how to |
+: : : map indices in `start_indices` :
+: : : to legal indices into operand. :
+
+For convenience, we label dimensions in the output array not in `offset_dims`
+as `batch_dims`. For example, if the output has rank `5` and `offset_dims` is
+{`2`, `4`}, then `batch_dims` is {`0`, `1`, `3`}.
+
+The output is an array of rank `batch_dims.size` + `operand.rank` -
+`collapsed_slice_dims.size`.
+
+If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider
+`start_indices` to have a trailing `1` dimension (i.e. if `start_indices` was of
+shape `[6,7]` and `index_vector_dim` is `2` then we implicitly consider the
+shape of `start_indices` to be `[6,7,1]`).
+
+The bound of the output array along dimension `i` is computed as follows:
+
+ 1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for
+ some `k`) then we pick the corresponding dimension bounds out of
+ `start_indices.shape`, skipping `index_vector_dim` (i.e. pick
+ `start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and
+ `start_indices.shape.dims`[`k`+`1`] otherwise).
+
+ 2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for
+ some `k`) then we pick the corresponding bound out of `slice_sizes` after
+ accounting for `collapsed_slice_dims` (i.e. we pick
+ `adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes`
+ with the bounds at indices `collapsed_slice_dims` removed).
+
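As a worked example of these two rules (the dimension numbers below are one
consistent choice for the `[16,11]` gather pictured in the informal section
further down): take an `operand` of shape `[16,11]`, `start_indices` of shape
`[5,2]`, `index_vector_dim = 1`, `offset_dims = {1, 2}`,
`collapsed_slice_dims = {}` and `slice_sizes = [8, 6]`. The output has rank
`1 + 2 - 0 = 3`: dimension `0` is a batch dimension and takes its bound `5`
from `start_indices.shape` with `index_vector_dim` skipped, while dimensions
`1` and `2` are offset dimensions and take their bounds `8` and `6` from
`adjusted_slice_sizes` (equal to `slice_sizes` here, since nothing is
collapsed), giving an output shape of `[5,8,6]`.
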
+Formally, the operand index `In` corresponding to an output index `Out` is
+computed as follows:
+
+ 1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out
+ vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where
+ Combine(A, b) inserts b at position `index_vector_dim` into A. Note that
+ this is well defined even if `G` is empty -- if `G` is empty then `S` =
+ `start_indices`.
+
+ 2. Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by
+ scattering `S` using `start_index_map`. More precisely:
+ 1. `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` <
+ `start_index_map.size`.
2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
- 3. Create an index `W`<sub>`in`</sub> into `operand` by scattering the indices
- at the output window dimensions in `Out` according to
- the `elided_window_dims` set (`W`<sub>`in`</sub> is the _window index_
- mentioned above). More precisely:
- 1. `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `Out`[`k`] if
- `k` < `output_window_dims.size` (`window_dims_to_operand_dims` is
- defined below).
- 2. `W`<sub>`in`</sub>[`_`] = `0` otherwise.
- 4. `In` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
+
+ 3. Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices
+ at the offset dimensions in `Out` according to the `collapsed_slice_dims`
+ set. More precisely:
+ 1. `O`<sub>`in`</sub>[`expand_offset_dims`(`k`)] =
+ `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size`
+ (`expand_offset_dims` is defined below).
+ 2. `O`<sub>`in`</sub>[`_`] = `0` otherwise.
+ 4. `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
addition.
-`window_dims_to_operand_dims` is the monotonic function with domain [`0`,
-`output_window_dims.size`) and range [`0`, `operand.rank`) \
-`elided_window_dims`. So if, e.g., `output_window_dims.size` is `4`,
-`operand.rank` is `6` and `elided_window_dims` is {`0`, `2`} then
-`window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.
+`expand_offset_dims` is the monotonic function with domain [`0`,
+`offset_dims.size`) and range [`0`, `operand.rank`) \ `collapsed_slice_dims`.
+So if, e.g., `offset_dims.size` is `4`, `operand.rank` is `6` and
+`collapsed_slice_dims` is {`0`, `2`} then `expand_offset_dims` is {`0`→`1`,
+`1`→`3`, `2`→`4`, `3`→`5`}.
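+
+To make the formal description concrete, here is a minimal Python/NumPy sketch
+of steps 1 through 4, including `expand_offset_dims`. It is an illustration
+written for this document rather than XLA code, and it assumes that
+`index_vector_dim` is strictly less than `start_indices.rank` (no implicit
+trailing `1` dimension) and that all indices are in bounds:
+
+```python
+import numpy as np
+
+def gather_element(operand, start_indices, offset_dims, collapsed_slice_dims,
+                   start_index_map, index_vector_dim, out_index):
+  """Returns operand[In] for a single output index Out (steps 1-4 above)."""
+  batch_dims = [d for d in range(len(out_index)) if d not in offset_dims]
+
+  # Step 1: G = batch coordinates of Out; S[i] = start_indices[Combine(G, i)],
+  # where Combine inserts i at position index_vector_dim.
+  G = [out_index[d] for d in batch_dims]
+  S = [start_indices[tuple(G[:index_vector_dim] + [i] + G[index_vector_dim:])]
+       for i in range(start_indices.shape[index_vector_dim])]
+
+  # Step 2: scatter S into a full starting index S_in using start_index_map;
+  # unmapped operand dimensions start at 0.
+  S_in = [0] * operand.ndim
+  for k, operand_dim in enumerate(start_index_map):
+    S_in[operand_dim] = S[k]
+
+  # expand_offset_dims: the monotonic map onto the operand dimensions that
+  # are not collapsed.
+  expand_offset_dims = [d for d in range(operand.ndim)
+                        if d not in collapsed_slice_dims]
+
+  # Step 3: scatter the offset coordinates of Out into O_in.
+  O_in = [0] * operand.ndim
+  for k, out_dim in enumerate(offset_dims):
+    O_in[expand_offset_dims[k]] = out_index[out_dim]
+
+  # Step 4: In = O_in + S_in (element-wise addition).
+  return operand[tuple(s + o for s, o in zip(S_in, O_in))]
+```
+
+Note that `slice_sizes` does not appear in this per-element sketch: for a
+single output index it only constrains the legal range of the offset
+coordinates in `Out`.
+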
### Informal Description and Examples
-`index_vector_dim` is set to `gather_indices.rank` - `1` in all of the
-examples that follow. More interesting values for `index_vector_dim`
-does not change the operation fundamentally, but makes the visual representation
-more cumbersome.
+Informally, every index `Out` in the output array corresponds to an element `E`
+in the operand array, computed as follows:
+
+ - We use the batch dimensions in `Out` to look up a starting index from
+ `start_indices`.
+
+ - We use `start_index_map` to map the starting index (which may have size less
+ than operand.rank) to a "full" starting index into operand.
+
+ - We dynamic-slice out a slice with size `slice_sizes` using the full starting
+ index.
+
+ - We reshape the slice by collapsing the `collapsed_slice_dims` dimensions.
+    Since all collapsed slice dimensions must have bound `1`, this reshape is
+    always legal.
+
+ - We use the offset dimensions in `Out` to index into this slice to get the
+ input element, `E`, corresponding to output index `Out`.
+
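+The bullets above can also be sketched end to end. The following NumPy
+fragment is, again, only an illustration: it assumes `index_vector_dim` is the
+last dimension of `start_indices`, that all starting indices are in bounds,
+and, for simplicity, it always places the batch dimensions first in the output
+(the real operation lets `offset_dims` position the slice dimensions anywhere
+in the output shape):
+
+```python
+import numpy as np
+
+def gather_ref(operand, start_indices, collapsed_slice_dims, start_index_map,
+               slice_sizes):
+  """Batched dynamic-slice restatement of the bullets above."""
+  batch_shape = start_indices.shape[:-1]
+  out = None
+  for G in np.ndindex(*batch_shape):
+    # Look up the starting index and map it to a full starting index via
+    # start_index_map; unmapped operand dimensions start at 0.
+    full_start = [0] * operand.ndim
+    for k, operand_dim in enumerate(start_index_map):
+      full_start[operand_dim] = start_indices[G + (k,)]
+    # Dynamic-slice a slice_sizes-shaped region at the full starting index.
+    region = operand[tuple(slice(s, s + z)
+                           for s, z in zip(full_start, slice_sizes))]
+    # Collapse the size-1 dimensions listed in collapsed_slice_dims.
+    region = region.reshape([z for d, z in enumerate(region.shape)
+                             if d not in collapsed_slice_dims])
+    if out is None:
+      out = np.empty(batch_shape + region.shape, dtype=operand.dtype)
+    out[G] = region  # the offset coordinates index into this slice
+  return out
+```
+
+For instance, `gather_ref(np.arange(9).reshape(3, 3), np.array([[1], [2]]),
+collapsed_slice_dims=[0], start_index_map=[0], slice_sizes=[1, 3])` returns
+rows `1` and `2` of the `3x3` operand as a `[2,3]` array.
+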
+`index_vector_dim` is set to `start_indices.rank` - `1` in all of the
+examples that follow. More interesting values for `index_vector_dim` do not
+change the operation fundamentally, but they make the visual representation
+more cumbersome.
To get an intuition on how all of the above fits together, let's look at an
-example that gathers 5 slices of shape `[8,6]` from a `[16,11]` tensor. The
-position of a slice into the `[16,11]` tensor can be represented as an index
+example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array. The
+position of a slice into the `[16,11]` array can be represented as an index
vector of shape `S64[2]`, so the set of 5 positions can be represented as a
-`S64[5,2]` tensor.
+`S64[5,2]` array.
The behavior of the gather operation can then be depicted as an index
-transformation that takes [`G`,`W`<sub>`0`</sub>,`W`<sub>`1`</sub>], an index in
-the output shape, and maps it to an element in the input tensor in the following
+transformation that takes [`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>], an index in
+the output shape, and maps it to an element in the input array in the following
way:
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../../images/ops_xla_gather_0.svg">
</div>
-We first select an (`X`,`Y`) vector from the gather indices tensor using `G`.
-The element in the output tensor at index
-[`G`,`W`<sub>`0`</sub>,`W`<sub>`1`</sub>] is then the element in the input
-tensor at index [`X`+`W`<sub>`0`</sub>,`Y`+`W`<sub>`1`</sub>].
+We first select an (`X`,`Y`) vector from the gather indices array using `G`.
+The element in the output array at index
+[`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>] is then the element in the input
+array at index [`X`+`O`<sub>`0`</sub>,`Y`+`O`<sub>`1`</sub>].
-`window_bounds` is `[8,6]`, which decides the range of W<sub>`0`</sub> and
-W<sub>`1`</sub>, and this in turn decides the bounds of the slice.
+`slice_sizes` is `[8,6]`, which decides the range of `O`<sub>`0`</sub> and
+`O`<sub>`1`</sub>, and this in turn decides the bounds of the slice.
This gather operation acts as a batch dynamic slice with `G` as the batch
dimension.
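+
+As a sanity check, the same example can be written directly in NumPy. The
+start indices below are made up, and placing the batch dimension first (so the
+output has shape `[5,8,6]`) is one layout consistent with the figure:
+
+```python
+import numpy as np
+
+operand = np.arange(16 * 11).reshape(16, 11)                 # the [16,11] array
+start_indices = np.array([[0, 0], [2, 3], [4, 1], [8, 5], [7, 0]])  # S64[5,2]
+
+# output[G, O0, O1] == operand[X + O0, Y + O1] where (X, Y) = start_indices[G].
+output = np.stack([operand[x:x + 8, y:y + 6] for x, y in start_indices])
+assert output.shape == (5, 8, 6)
+```
+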
The gather indices may be multidimensional. For instance, a more general
-version of the example above using a "gather indices" tensor of shape `[4,5,2]`
+version of the example above using a "gather indices" array of shape `[4,5,2]`
would translate indices like this:
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
@@ -1298,25 +1288,25 @@ would translate indices like this:
</div>
-Again, this acts as a batch dynamic slice `G`<sub>`0`</sub> and
-`G`<sub>`1`</sub> as the batch dimensions. The window bounds are still `[8,6]`.
+Again, this acts as a batch dynamic slice with `G`<sub>`0`</sub> and
+`G`<sub>`1`</sub> as the batch dimensions. The slice sizes are still `[8,6]`.
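+
+Continuing the hypothetical NumPy illustration (with all-zero start indices,
+made up purely for brevity), the extra batch dimension adds a leading `[4,5]`
+to the output shape:
+
+```python
+import numpy as np
+
+operand = np.arange(16 * 11).reshape(16, 11)
+start_indices = np.zeros((4, 5, 2), dtype=np.int64)  # "gather indices" of shape [4,5,2]
+# G0 and G1 index the [4,5] grid of start positions; slice sizes stay [8,6].
+output = np.array([[operand[x:x + 8, y:y + 6] for x, y in row]
+                   for row in start_indices])
+assert output.shape == (4, 5, 8, 6)
+```
+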
The gather operation in XLA generalizes the informal semantics outlined above in
the following ways:
- 1. We can configure which dimensions in the output shape are the window
- dimensions (dimensions containing `W`<sub>`0`</sub>, `W`<sub>`1`</sub> in
- the last example). The output gather dimensions (dimensions containing
+ 1. We can configure which dimensions in the output shape are the offset
+ dimensions (dimensions containing `O`<sub>`0`</sub>, `O`<sub>`1`</sub> in
+ the last example). The output batch dimensions (dimensions containing
`G`<sub>`0`</sub>, `G`<sub>`1`</sub> in the last example) are defined to be
- the output dimensions that are not window dimensions.
+ the output dimensions that are not offset dimensions.
- 2. The number of output window dimensions explicitly present in the output
+ 2. The number of output offset dimensions explicitly present in the output
shape may be smaller than the input rank. These "missing" dimensions, which
- are listed explicitly as `elided_window_dims`, must have a window bound of
- `1`. Since they have a window bound of `1` the only valid index for them is
+ are listed explicitly as `collapsed_slice_dims`, must have a slice size of
+ `1`. Since they have a slice size of `1` the only valid index for them is
`0` and eliding them does not introduce ambiguity.
- 3. The slice extracted from the "Gather Indices" tensor ((`X`, `Y`) in the last
- example) may have fewer elements than the input tensor rank, and an explicit
+ 3. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last
+ example) may have fewer elements than the input array rank, and an explicit
mapping dictates how the index should be expanded to have the same rank as
the input.
@@ -1327,20 +1317,19 @@ As a final example, we use (2) and (3) to implement `tf.gather_nd`:
</div>
`G`<sub>`0`</sub> and `G`<sub>`1`</sub> are used to slice out a starting index
-from the gather indices tensor as usual, except the starting index has only one
-element, `X`. Similarly, there is only one output window index with the value
-`W`<sub>`0`</sub>. However, before being used as indices into the input tensor,
-these are expanded in accordance to "Gather Index Mapping"
-(`gather_dims_to_operand_dims` in the formal description) and "Window Mapping"
-(`window_dims_to_operand_dims` in the formal description) into
-[`0`,`W`<sub>`0`</sub>] and [`X`,`0`] respectively, adding up to
-[`X`,`W`<sub>`0`</sub>]. In other words, the output index
-[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`W`<sub>`0`</sub>] maps to the input index
+from the gather indices array as usual, except the starting index has only one
+element, `X`. Similarly, there is only one output offset index with the value
+`O`<sub>`0`</sub>. However, before being used as indices into the input array,
+these are expanded in accordance with "Gather Index Mapping" (`start_index_map`
+in the formal description) and "Offset Mapping" (`expand_offset_dims` in the
+formal description) into [`X`,`0`] and [`0`,`O`<sub>`0`</sub>] respectively,
+adding up to [`X`,`O`<sub>`0`</sub>]. In other words, the output index
+[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index
-[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us
-the semantics for `tf.gather_nd`.
+[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`O`<sub>`0`</sub>],
+which gives us the semantics for `tf.gather_nd`.
-`window_bounds` for this case is `[1,11]`. Intuitively this means that every
-index `X` in the gather indices tensor picks an entire row and the result is the
+`slice_sizes` for this case is `[1,11]`. Intuitively, this means that every
+index `X` in the gather indices array picks an entire row and the result is the
concatenation of all these rows.
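+
+The following NumPy rendering of this case (with made-up index values) shows
+how `slice_sizes` = `[1,11]` picks whole rows once the size-`1` dimension is
+collapsed away:
+
+```python
+import numpy as np
+
+operand = np.arange(16 * 11).reshape(16, 11)         # the [16,11] input array
+gather_indices = np.array([[[2], [5]], [[0], [9]]])  # shape [2,2,1], made-up row numbers
+
+# Each [G0, G1] picks row gather_indices[G0, G1, 0]; the output is [2,2,11].
+output = operand[gather_indices[..., 0]]
+assert output.shape == (2, 2, 11)
+```
+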
## GetTupleElement