aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-06 20:48:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 20:52:28 -0700
commit44efcf0db7b9204a77710a7f076c904d0e13e6fa (patch)
tree2114672b05239fcdb9554bcf26354a81b4f8565b /tensorflow/compiler/xla/service/hlo_instructions.cc
parenta7e3047fea74a43174c063320fd0cb6bb6dcceb1 (diff)
Split out HloDomainInstruction as subclass form HloInstruction.
PiperOrigin-RevId: 211916428
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc39
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 4e3e0c055e..76712d73db 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -2224,4 +2224,43 @@ string HloDotInstruction::DotDimensionNumbersToString() const {
return StrJoin(result, ", ");
}
+
+HloDomainInstruction::HloDomainInstruction(
+ const Shape& shape, HloInstruction* operand,
+ std::unique_ptr<DomainMetadata> operand_side_metadata,
+ std::unique_ptr<DomainMetadata> user_side_metadata)
+ : HloInstruction(HloOpcode::kDomain, shape),
+ operand_side_metadata_(std::move(operand_side_metadata)),
+ user_side_metadata_(std::move(user_side_metadata)) {
+ AppendOperand(operand);
+}
+
+std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
+ return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
+ "\", entry=", user_side_metadata_->ToString(),
+ ", exit=", operand_side_metadata_->ToString(), "}")};
+ }
+ return {};
+}
+
+bool HloDomainInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
+ return operand_side_metadata().Matches(
+ casted_other.operand_side_metadata()) &&
+ user_side_metadata().Matches(casted_other.user_side_metadata());
+}
+
+std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return absl::make_unique<HloDomainInstruction>(
+ shape, new_operands[0], operand_side_metadata_->Clone(),
+ user_side_metadata_->Clone());
+}
} // namespace xla