aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-06 20:09:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 20:14:24 -0700
commitac8cf2ad5d01010b978c5b41c2fac22ee69a90c4 (patch)
tree06840591db9d2a077b28fe28f73baae913065550 /tensorflow/compiler/xla/service/hlo_instructions.cc
parent1cc48be8da90c2d5d3a2ebdf6ed46be623fa0c03 (diff)
Split out HloDotInstruction as subclass from HloInstruction.
PiperOrigin-RevId: 211912785
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc63
1 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index ad87aa1123..4e3e0c055e 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1663,6 +1663,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const {
*proto.mutable_convolution_dimension_numbers() =
convolution_dimension_numbers_;
proto.set_feature_group_count(feature_group_count_);
+ *proto.mutable_precision_config() = precision_config();
return proto;
}
@@ -2161,4 +2162,66 @@ std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
}
+HloDotInstruction::HloDotInstruction(
+ const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig& precision_config)
+ : HloInstruction(HloOpcode::kDot, shape),
+ dot_dimension_numbers_(dimension_numbers) {
+ AppendOperand(lhs);
+ AppendOperand(rhs);
+ set_precision_config(precision_config);
+}
+
+HloInstructionProto HloDotInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
+ *proto.mutable_precision_config() = precision_config();
+ return proto;
+}
+
+std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {DotDimensionNumbersToString()};
+}
+
+bool HloDotInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloDotInstruction&>(other);
+ return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
+ casted_other.dot_dimension_numbers());
+}
+
+std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 2);
+ return absl::make_unique<HloDotInstruction>(
+ shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
+ precision_config());
+}
+
+string HloDotInstruction::DotDimensionNumbersToString() const {
+ std::vector<string> result;
+ const DotDimensionNumbers& dnums = dot_dimension_numbers_;
+ if (!dnums.lhs_batch_dimensions().empty()) {
+ result.push_back(StrCat("lhs_batch_dims={",
+ StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
+ }
+ result.push_back(StrCat("lhs_contracting_dims={",
+ StrJoin(dnums.lhs_contracting_dimensions(), ","),
+ "}"));
+
+ if (!dnums.rhs_batch_dimensions().empty()) {
+ result.push_back(StrCat("rhs_batch_dims={",
+ StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
+ }
+ result.push_back(StrCat("rhs_contracting_dims={",
+ StrJoin(dnums.rhs_contracting_dimensions(), ","),
+ "}"));
+
+ return StrJoin(result, ", ");
+}
} // namespace xla