diff options
author | 2018-09-06 20:09:38 -0700 | |
---|---|---|
committer | 2018-09-06 20:14:24 -0700 | |
commit | ac8cf2ad5d01010b978c5b41c2fac22ee69a90c4 (patch) | |
tree | 06840591db9d2a077b28fe28f73baae913065550 /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | 1cc48be8da90c2d5d3a2ebdf6ed46be623fa0c03 (diff) |
Split out HloDotInstruction as subclass from HloInstruction.
PiperOrigin-RevId: 211912785
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index ad87aa1123..4e3e0c055e 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1663,6 +1663,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const { *proto.mutable_convolution_dimension_numbers() = convolution_dimension_numbers_; proto.set_feature_group_count(feature_group_count_); + *proto.mutable_precision_config() = precision_config(); return proto; } @@ -2161,4 +2162,66 @@ std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl( return absl::make_unique<HloIotaInstruction>(shape, iota_dimension()); } +HloDotInstruction::HloDotInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) + : HloInstruction(HloOpcode::kDot, shape), + dot_dimension_numbers_(dimension_numbers) { + AppendOperand(lhs); + AppendOperand(rhs); + set_precision_config(precision_config); +} + +HloInstructionProto HloDotInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_; + *proto.mutable_precision_config() = precision_config(); + return proto; +} + +std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {DotDimensionNumbersToString()}; +} + +bool HloDotInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloDotInstruction&>(other); + return protobuf_util::ProtobufEquals(dot_dimension_numbers(), + casted_other.dot_dimension_numbers()); +} + +std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span<HloInstruction* const> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return absl::make_unique<HloDotInstruction>( + shape, new_operands[0], new_operands[1], dot_dimension_numbers_, + precision_config()); +} + +string HloDotInstruction::DotDimensionNumbersToString() const { + std::vector<string> result; + const DotDimensionNumbers& dnums = dot_dimension_numbers_; + if (!dnums.lhs_batch_dimensions().empty()) { + result.push_back(StrCat("lhs_batch_dims={", + StrJoin(dnums.lhs_batch_dimensions(), ","), "}")); + } + result.push_back(StrCat("lhs_contracting_dims={", + StrJoin(dnums.lhs_contracting_dimensions(), ","), + "}")); + + if (!dnums.rhs_batch_dimensions().empty()) { + result.push_back(StrCat("rhs_batch_dims={", + StrJoin(dnums.rhs_batch_dimensions(), ","), "}")); + } + result.push_back(StrCat("rhs_contracting_dims={", + StrJoin(dnums.rhs_contracting_dimensions(), ","), + "}")); + + return StrJoin(result, ", "); +} } // namespace xla |