diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-08-16 14:44:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 14:47:49 -0700 |
commit | d43820b9eff0cc863de2bbfb142afe92bf5afd00 (patch) | |
tree | e019bdc0bb52496436e0ff92d76b728165e3a420 /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | 8235e83c442744a1285ce97e5dfc2a6556f9f667 (diff) |
Improve gather ergonomics by renaming fields.
This CL renames the various inputs to the Gather HLO to be more mnemonic by
making it more obviously a batch dynamic-slice. The replacements I made are:
s/elided_window_dims/collapsed_slice_dims/g
s/window_bounds/slice_sizes/g
s/gather_dims_to_operand_dims/start_index_map/g
s/gather_indices/start_indices/g
s/output_window_dims/offset_dims/g
PiperOrigin-RevId: 209051067
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 57 |
1 files changed, 28 insertions, 29 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 233cdda7b0..4fdf4360e6 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1965,51 +1965,50 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl( } HloGatherInstruction::HloGatherInstruction( - const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice<int64> window_bounds) + tensorflow::gtl::ArraySlice<int64> slice_sizes) : HloInstruction(HloOpcode::kGather, shape) { AppendOperand(operand); - AppendOperand(gather_indices); + AppendOperand(start_indices); gather_dimension_numbers_ = MakeUnique<GatherDimensionNumbers>(gather_dim_numbers); - c_copy(window_bounds, std::back_inserter(gather_window_bounds_)); + c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_)); } string HloGatherInstruction::GatherDimensionNumbersToString() const { CHECK(gather_dimension_numbers_ != nullptr); - string output_window_dims = - StrCat("output_window_dims={", - Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); - string elided_window_dims = - StrCat("elided_window_dims={", - Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); - string gather_dims_to_operand_dims = StrCat( - "gather_dims_to_operand_dims={", - Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + string offset_dims = + StrCat("offset_dims={", + Join(gather_dimension_numbers_->offset_dims(), ","), "}"); + string collapsed_slice_dims = + StrCat("collapsed_slice_dims={", + Join(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}"); + string start_index_map = + StrCat("start_index_map={", + Join(gather_dimension_numbers_->start_index_map(), ","), "}"); string index_vector_dim = StrCat( "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); return Join<std::initializer_list<string>>( - {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, - index_vector_dim}, + {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim}, ", "); } /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice<int64> output_window_dims, - tensorflow::gtl::ArraySlice<int64> elided_window_dims, - tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims, + tensorflow::gtl::ArraySlice<int64> offset_dims, + tensorflow::gtl::ArraySlice<int64> collapsed_slice_dims, + tensorflow::gtl::ArraySlice<int64> start_index_map, int64 index_vector_dim) { GatherDimensionNumbers gather_dim_numbers; - for (int64 output_window_dim : output_window_dims) { - gather_dim_numbers.add_output_window_dims(output_window_dim); + for (int64 output_window_dim : offset_dims) { + gather_dim_numbers.add_offset_dims(output_window_dim); } - for (int64 elided_window_dim : elided_window_dims) { - gather_dim_numbers.add_elided_window_dims(elided_window_dim); + for (int64 elided_window_dim : collapsed_slice_dims) { + gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim); } - for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { - gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); + for (int64 gather_dim_to_input_dim : start_index_map) { + gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim); } gather_dim_numbers.set_index_vector_dim(index_vector_dim); @@ -2019,8 +2018,8 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const { HloInstructionProto HloGatherInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers(); - for (int64 bound : gather_window_bounds()) { - proto.add_gather_window_bounds(bound); + for (int64 bound : gather_slice_sizes()) { + proto.add_gather_slice_sizes(bound); } return proto; } @@ -2028,7 +2027,7 @@ HloInstructionProto HloGatherInstruction::ToProto() const { std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { return {GatherDimensionNumbersToString(), - StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")}; + StrCat("slice_sizes={", Join(gather_slice_sizes(), ","), "}")}; } bool HloGatherInstruction::IdenticalSlowPath( @@ -2039,7 +2038,7 @@ bool HloGatherInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals( gather_dimension_numbers(), casted_other.gather_dimension_numbers()) && - gather_window_bounds() == casted_other.gather_window_bounds(); + gather_slice_sizes() == casted_other.gather_slice_sizes(); } std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl( @@ -2049,7 +2048,7 @@ std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl( CHECK_EQ(new_operands.size(), 2); return MakeUnique<HloGatherInstruction>( shape, new_operands[0], new_operands[1], gather_dimension_numbers(), - gather_window_bounds()); + gather_slice_sizes()); } HloScatterInstruction::HloScatterInstruction( |